mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #6866 from gnecula:tf_pjit
PiperOrigin-RevId: 376989780
This commit is contained in:
commit
8e6101c6a1
@ -11,6 +11,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
## jax 0.2.14 (unreleased)
|
## jax 0.2.14 (unreleased)
|
||||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master).
|
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master).
|
||||||
* New features:
|
* New features:
|
||||||
|
* The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`.
|
||||||
|
|
||||||
* Breaking changes:
|
* Breaking changes:
|
||||||
|
|
||||||
|
@ -475,8 +475,8 @@ in [savedmodel_test.py](https://github.com/google/jax/blob/master/jax/experiment
|
|||||||
|
|
||||||
### Missing converter features
|
### Missing converter features
|
||||||
|
|
||||||
There is currently no support for replicated (e.g. `pmap`) or multi-device
|
There is currently no support for `pmap` or`xmap`, nor for the collective
|
||||||
(e.g. `sharded_jit`) functions. The collective operations are not yet handled.
|
operations. There is support for `sharded_jit` and `pjit`.
|
||||||
|
|
||||||
### No SavedModel fine-tuning
|
### No SavedModel fine-tuning
|
||||||
|
|
||||||
|
@ -31,6 +31,8 @@ from jax._src.lax import lax
|
|||||||
from jax._src.lax import linalg as lax_linalg
|
from jax._src.lax import linalg as lax_linalg
|
||||||
import jax._src.random
|
import jax._src.random
|
||||||
from jax.api_util import flatten_fun
|
from jax.api_util import flatten_fun
|
||||||
|
from jax.experimental import maps
|
||||||
|
from jax.experimental import pjit
|
||||||
from jax.interpreters import ad
|
from jax.interpreters import ad
|
||||||
from jax.interpreters import pxla
|
from jax.interpreters import pxla
|
||||||
from jax.interpreters import sharded_jit
|
from jax.interpreters import sharded_jit
|
||||||
@ -397,9 +399,6 @@ def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal) -> Sequence[TfVal]:
|
|||||||
return tuple(v for v, _ in out_with_avals)
|
return tuple(v for v, _ in out_with_avals)
|
||||||
|
|
||||||
|
|
||||||
### tracer
|
|
||||||
|
|
||||||
|
|
||||||
def _aval_to_tf_shape(aval: core.AbstractValue) -> Tuple[Optional[int], ...]:
|
def _aval_to_tf_shape(aval: core.AbstractValue) -> Tuple[Optional[int], ...]:
|
||||||
"""Generate a TF shape, possibly containing None for polymorphic dimensions."""
|
"""Generate a TF shape, possibly containing None for polymorphic dimensions."""
|
||||||
return tuple(
|
return tuple(
|
||||||
@ -729,8 +728,9 @@ class TensorFlowTrace(core.Trace):
|
|||||||
def process_call(self, call_primitive: core.Primitive, f: lu.WrappedFun,
|
def process_call(self, call_primitive: core.Primitive, f: lu.WrappedFun,
|
||||||
tracers: Sequence[TensorFlowTracer], params):
|
tracers: Sequence[TensorFlowTracer], params):
|
||||||
assert call_primitive.multiple_results
|
assert call_primitive.multiple_results
|
||||||
vals: Sequence[TfVal] = [t.val for t in tracers]
|
vals: Sequence[TfVal] = tuple(t.val for t in tracers)
|
||||||
f = _interpret_subtrace(f, self.main, tuple(t.aval for t in tracers))
|
avals: Sequence[core.AbstractValue] = tuple(t.aval for t in tracers)
|
||||||
|
f = _interpret_subtrace(f, self.main, avals)
|
||||||
with core.new_sublevel():
|
with core.new_sublevel():
|
||||||
if call_primitive == core.named_call_p:
|
if call_primitive == core.named_call_p:
|
||||||
with tf.name_scope(_sanitize_scope_name(params["name"])):
|
with tf.name_scope(_sanitize_scope_name(params["name"])):
|
||||||
@ -813,6 +813,8 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
for unexpected in xla.call_translations: # Call primitives are inlined
|
for unexpected in xla.call_translations: # Call primitives are inlined
|
||||||
|
if unexpected is pjit.pjit_p:
|
||||||
|
continue
|
||||||
tf_impl[unexpected] = functools.partial(_unexpected_primitive, unexpected)
|
tf_impl[unexpected] = functools.partial(_unexpected_primitive, unexpected)
|
||||||
|
|
||||||
# Primitives that are not yet implemented must be explicitly declared here.
|
# Primitives that are not yet implemented must be explicitly declared here.
|
||||||
@ -2494,8 +2496,8 @@ def split_to_logical_devices(tensor: TfVal,
|
|||||||
Returns:
|
Returns:
|
||||||
an annotated tensor.
|
an annotated tensor.
|
||||||
"""
|
"""
|
||||||
# This corresponds to the sharding annotations in
|
# TODO: this is only for sharded_jit. Either remove, or implement in terms
|
||||||
# xla_bridge._sharding_to_proto.
|
# of _shard_values.
|
||||||
if partition_dimensions is None:
|
if partition_dimensions is None:
|
||||||
return xla_sharding.replicate(tensor, use_sharding_op=True)
|
return xla_sharding.replicate(tensor, use_sharding_op=True)
|
||||||
num_partition_splits = np.prod(partition_dimensions)
|
num_partition_splits = np.prod(partition_dimensions)
|
||||||
@ -2504,6 +2506,24 @@ def split_to_logical_devices(tensor: TfVal,
|
|||||||
return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
|
return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _shard_value(mesh: maps.Mesh,
|
||||||
|
val: TfVal,
|
||||||
|
aval: core.AbstractValue,
|
||||||
|
axis_resources: pjit.ParsedPartitionSpec) -> TfVal:
|
||||||
|
"""Apply sharding to a TfVal."""
|
||||||
|
sharding_proto: xla_client.OpSharding = pjit.get_aval_sharding_proto(
|
||||||
|
aval, axis_resources, mesh)
|
||||||
|
# To use xla_sharding.py, we must have a xla_data_pb2.OpSharding.
|
||||||
|
xla_sharding_proto: xla_data_pb2.OpSharding = (
|
||||||
|
xla_data_pb2.OpSharding(
|
||||||
|
type=int(sharding_proto.type),
|
||||||
|
tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
|
||||||
|
tile_assignment_devices=sharding_proto.tile_assignment_devices,
|
||||||
|
replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim))
|
||||||
|
return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor(
|
||||||
|
val, use_sharding_op=True)
|
||||||
|
|
||||||
|
|
||||||
def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal],
|
def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal],
|
||||||
in_parts: Sequence[pxla.PartitionsOrReplicated],
|
in_parts: Sequence[pxla.PartitionsOrReplicated],
|
||||||
out_parts_thunk,
|
out_parts_thunk,
|
||||||
@ -2520,12 +2540,51 @@ def _sharded_call(f: lu.WrappedFun, vals: Sequence[TfVal],
|
|||||||
return sharded_vals_out
|
return sharded_vals_out
|
||||||
|
|
||||||
|
|
||||||
def _sharding_constraint(arg: TfVal, *,
|
def _sharded_jit_sharding_constraint(arg: TfVal, *,
|
||||||
partitions: pxla.PartitionsOrReplicated):
|
partitions: pxla.PartitionsOrReplicated,
|
||||||
|
_in_avals: Sequence[core.ShapedArray],
|
||||||
|
_out_aval: core.ShapedArray):
|
||||||
|
del _in_avals, _out_aval
|
||||||
return split_to_logical_devices(arg, partitions)
|
return split_to_logical_devices(arg, partitions)
|
||||||
|
|
||||||
|
|
||||||
tf_impl[sharded_jit.sharding_constraint_p] = _sharding_constraint
|
tf_impl_with_avals[sharded_jit.sharding_constraint_p] = _sharded_jit_sharding_constraint
|
||||||
|
|
||||||
|
|
||||||
|
def _pjit(*args: TfVal,
|
||||||
|
jaxpr: core.ClosedJaxpr,
|
||||||
|
in_axis_resources: Sequence[pjit.ParsedPartitionSpec],
|
||||||
|
out_axis_resources: Sequence[pjit.ParsedPartitionSpec],
|
||||||
|
resource_env: maps.ResourceEnv,
|
||||||
|
donated_invars,
|
||||||
|
name: str,
|
||||||
|
_in_avals: Sequence[core.ShapedArray],
|
||||||
|
_out_aval: core.ShapedArray) -> TfVal:
|
||||||
|
del donated_invars, name
|
||||||
|
# TODO: add `name` to the name stack
|
||||||
|
shard_value_for_mesh = functools.partial(_shard_value, resource_env.physical_mesh)
|
||||||
|
# Apply sharding annotation to the arguments
|
||||||
|
sharded_args: Sequence[TfVal] = tuple(
|
||||||
|
util.safe_map(shard_value_for_mesh, args, _in_avals, in_axis_resources))
|
||||||
|
results = _interpret_jaxpr(jaxpr, *sharded_args)
|
||||||
|
sharded_results: Sequence[TfVal] = tuple(
|
||||||
|
util.safe_map(shard_value_for_mesh, results, _out_aval, out_axis_resources))
|
||||||
|
return tuple(sharded_results)
|
||||||
|
|
||||||
|
|
||||||
|
tf_impl_with_avals[pjit.pjit_p] = _pjit
|
||||||
|
|
||||||
|
|
||||||
|
def _pjit_sharding_constraint(arg: TfVal, *,
|
||||||
|
axis_resources: pjit.ParsedPartitionSpec,
|
||||||
|
resource_env: maps.ResourceEnv,
|
||||||
|
_in_avals: Sequence[core.ShapedArray],
|
||||||
|
_out_aval: core.ShapedArray,
|
||||||
|
**kwargs) -> TfVal:
|
||||||
|
return _shard_value(resource_env.physical_mesh, arg, _in_avals[0], axis_resources)
|
||||||
|
|
||||||
|
|
||||||
|
tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint
|
||||||
|
|
||||||
|
|
||||||
def _register_checkpoint_pytrees():
|
def _register_checkpoint_pytrees():
|
||||||
|
@ -13,9 +13,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tests for the jax2tf conversion of sharded_jit."""
|
"""Tests for the jax2tf conversion of sharded_jit."""
|
||||||
|
|
||||||
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Sequence
|
from typing import Any, Sequence
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
@ -25,6 +26,7 @@ from jax import test_util as jtu
|
|||||||
from jax.config import config
|
from jax.config import config
|
||||||
|
|
||||||
from jax.experimental import jax2tf
|
from jax.experimental import jax2tf
|
||||||
|
from jax.experimental import pjit
|
||||||
from jax.experimental.jax2tf.tests import tf_test_util
|
from jax.experimental.jax2tf.tests import tf_test_util
|
||||||
from jax.interpreters import sharded_jit
|
from jax.interpreters import sharded_jit
|
||||||
from jax.interpreters.sharded_jit import PartitionSpec as P
|
from jax.interpreters.sharded_jit import PartitionSpec as P
|
||||||
@ -36,6 +38,13 @@ import tensorflow as tf # type: ignore[import]
|
|||||||
|
|
||||||
config.parse_flags_with_absl()
|
config.parse_flags_with_absl()
|
||||||
|
|
||||||
|
def setUpModule():
|
||||||
|
jtu.set_spmd_lowering_flag(True)
|
||||||
|
|
||||||
|
def tearDownModule():
|
||||||
|
jtu.restore_spmd_lowering_flag()
|
||||||
|
|
||||||
|
|
||||||
LOG_HLO = True
|
LOG_HLO = True
|
||||||
|
|
||||||
class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
|
class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
|
||||||
@ -46,7 +55,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
|
|||||||
|
|
||||||
def _check_sharding_annotations(self,
|
def _check_sharding_annotations(self,
|
||||||
f_jax,
|
f_jax,
|
||||||
args,
|
args: Sequence[Any],
|
||||||
*,
|
*,
|
||||||
expected: Sequence[str],
|
expected: Sequence[str],
|
||||||
expected_opt: Sequence[str],
|
expected_opt: Sequence[str],
|
||||||
@ -57,6 +66,9 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
|
|||||||
We currently check the unoptimized HLO against `expected` on CPU and TPU,
|
We currently check the unoptimized HLO against `expected` on CPU and TPU,
|
||||||
and we check the optimized HLO against `expected_opt` on TPU only and
|
and we check the optimized HLO against `expected_opt` on TPU only and
|
||||||
only for JAX.
|
only for JAX.
|
||||||
|
|
||||||
|
See `self.AssertShardingAnnotations` for documentation of `expected`
|
||||||
|
and `expected_opt`.
|
||||||
"""
|
"""
|
||||||
if jtu.device_under_test() == "gpu":
|
if jtu.device_under_test() == "gpu":
|
||||||
raise unittest.SkipTest("Sharding HLO tests not useful for GPU")
|
raise unittest.SkipTest("Sharding HLO tests not useful for GPU")
|
||||||
@ -93,20 +105,20 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
|
|||||||
experimental_get_compiler_ir(*args)(stage="hlo",
|
experimental_get_compiler_ir(*args)(stage="hlo",
|
||||||
device_name=device_name)
|
device_name=device_name)
|
||||||
if LOG_HLO:
|
if LOG_HLO:
|
||||||
logging.info(f"[{self._testMethodName}] got TF OPT HLO {tf_hlo}")
|
logging.info(f"[{self._testMethodName}] got TF HLO {tf_hlo}")
|
||||||
self.AssertShardingAnnotations("TF before optimizations", tf_hlo, expected)
|
self.AssertShardingAnnotations("TF before optimizations", tf_hlo, expected)
|
||||||
tf_optimized_hlo = tf.function(f_tf, jit_compile=True).\
|
tf_optimized_hlo = tf.function(f_tf, jit_compile=True).\
|
||||||
experimental_get_compiler_ir(*args)(stage="optimized_hlo",
|
experimental_get_compiler_ir(*args)(stage="optimized_hlo",
|
||||||
device_name=device_name)
|
device_name=device_name)
|
||||||
if LOG_HLO:
|
if LOG_HLO:
|
||||||
logging.info(f"[{self._testMethodName}] XXX got TF OPT HLO "
|
logging.info(f"[{self._testMethodName}] got TF optimized HLO "
|
||||||
f"for {device_name}: {tf_optimized_hlo}")
|
f"for {device_name}: {tf_optimized_hlo}")
|
||||||
|
|
||||||
def AssertShardingAnnotations(self, what: str, hlo: str,
|
def AssertShardingAnnotations(self, what: str, hlo: str,
|
||||||
expected: Sequence[str]):
|
expected: Sequence[str]):
|
||||||
"""Args:
|
"""Args:
|
||||||
|
|
||||||
what: either 'JAX' or 'TF'
|
what: either 'JAX' or 'TF', used for messages only.
|
||||||
hlo: the text for the HLO module
|
hlo: the text for the HLO module
|
||||||
expected: a sequence of regexps that must occur in the hlo text. Each
|
expected: a sequence of regexps that must occur in the hlo text. Each
|
||||||
regexp must match a line, in order.
|
regexp must match a line, in order.
|
||||||
@ -128,7 +140,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
|
|||||||
f"!!! Not found[{next_expected_idx}] {expected[next_expected_idx]}")
|
f"!!! Not found[{next_expected_idx}] {expected[next_expected_idx]}")
|
||||||
raise self.failureException("\n".join(failure_msg))
|
raise self.failureException("\n".join(failure_msg))
|
||||||
|
|
||||||
def test_in_out(self):
|
def test_sharded_jit_in_out(self):
|
||||||
"""Test input and output sharding annotations."""
|
"""Test input and output sharding annotations."""
|
||||||
sharded_jax_func = sharded_jit.sharded_jit(
|
sharded_jax_func = sharded_jit.sharded_jit(
|
||||||
jnp.dot, in_parts=(P(1, 2), P(2, 1)), out_parts=P(1, 2))
|
jnp.dot, in_parts=(P(1, 2), P(2, 1)), out_parts=P(1, 2))
|
||||||
@ -152,7 +164,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
|
|||||||
],
|
],
|
||||||
num_partitions=2)
|
num_partitions=2)
|
||||||
|
|
||||||
def test_with_sharding_constraint(self):
|
def test_sharded_jit_with_sharding_constraint(self):
|
||||||
"""A sharding constraint in the middle."""
|
"""A sharding constraint in the middle."""
|
||||||
|
|
||||||
def jax_func(x, y):
|
def jax_func(x, y):
|
||||||
@ -180,7 +192,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
|
|||||||
],
|
],
|
||||||
num_partitions=2)
|
num_partitions=2)
|
||||||
|
|
||||||
def test_replicated(self):
|
def test_sharded_jit_replicated(self):
|
||||||
"""A replicated input and output."""
|
"""A replicated input and output."""
|
||||||
|
|
||||||
sharded_jax_func = sharded_jit.sharded_jit(
|
sharded_jax_func = sharded_jit.sharded_jit(
|
||||||
@ -203,6 +215,116 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
|
|||||||
],
|
],
|
||||||
num_partitions=2)
|
num_partitions=2)
|
||||||
|
|
||||||
|
@jtu.with_mesh([("x", 2)])
|
||||||
|
def test_pjit_basic1D(self):
|
||||||
|
|
||||||
|
@functools.partial(pjit.pjit,
|
||||||
|
in_axis_resources=(P("x"), P("x")),
|
||||||
|
out_axis_resources=None)
|
||||||
|
def jax_func(x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
shape = (8, 10)
|
||||||
|
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
|
||||||
|
hlo = jax.xla_computation(jax_func)(x, x).as_hlo_text()
|
||||||
|
print(f"HLO is {hlo}")
|
||||||
|
print(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}")
|
||||||
|
self._check_sharding_annotations(
|
||||||
|
jax_func, [x, x],
|
||||||
|
expected=[
|
||||||
|
r"f32\[8,10\].*sharding={devices=\[2,1\]", # x and y
|
||||||
|
r"f32\[8,10\].*sharding={replicated", # output
|
||||||
|
],
|
||||||
|
expected_opt=[
|
||||||
|
r"f32\[4,10\].*sharding={devices=\[2,1\]", # x and y
|
||||||
|
# TODO: why don't we see "sharding={replicated"
|
||||||
|
r"f32\[8,10\]", # output
|
||||||
|
],
|
||||||
|
num_partitions=2)
|
||||||
|
|
||||||
|
@jtu.with_mesh([("x", 2), ("y", 2)])
|
||||||
|
def test_pjit_basic2D(self):
|
||||||
|
@functools.partial(pjit.pjit,
|
||||||
|
in_axis_resources=(P(None, "x", "y"), P("y")),
|
||||||
|
out_axis_resources=P("x"))
|
||||||
|
def jax_func(x, y):
|
||||||
|
return x @ y
|
||||||
|
|
||||||
|
x_shape = (8, 6, 4)
|
||||||
|
y_shape = (4, 2)
|
||||||
|
x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
|
||||||
|
y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
|
||||||
|
self._check_sharding_annotations(
|
||||||
|
jax_func, [x, y],
|
||||||
|
expected=[
|
||||||
|
r"f32\[8,6,4\].*sharding={devices=\[1,2,2\]0,1,2,3", # x
|
||||||
|
r"f32\[4,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate", # y
|
||||||
|
r"f32\[8,6,2\].*sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate", # output
|
||||||
|
],
|
||||||
|
expected_opt=[
|
||||||
|
# TODO: relax ordering
|
||||||
|
r"f32\[2,2\].*sharding={devices=\[2,1,2\]0,2,1,3 last_tile_dim_replicate", # y
|
||||||
|
r"f32\[8,3,2\].*sharding={devices=\[1,2,2\]0,1,2,3", # x
|
||||||
|
# TODO: why we cannot see sharding={devices=\[2,1,1,2\]0,1,2,3 last_tile_dim_replicate?
|
||||||
|
r"bf16\[4,6,2\]", # output
|
||||||
|
],
|
||||||
|
num_partitions=4)
|
||||||
|
|
||||||
|
@jtu.with_mesh([("x", 2), ("y", 2)])
|
||||||
|
def test_pjit_TwoMeshAxisSharding(self):
|
||||||
|
@functools.partial(pjit.pjit,
|
||||||
|
in_axis_resources=P(("x", "y"),),
|
||||||
|
out_axis_resources=P(("x", "y"),))
|
||||||
|
def jax_func(x, y):
|
||||||
|
return x @ y
|
||||||
|
|
||||||
|
x_shape = (24, 8)
|
||||||
|
y_shape = (8, 2)
|
||||||
|
x = jnp.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
|
||||||
|
y = jnp.arange(np.prod(y_shape), dtype=np.float32).reshape(y_shape)
|
||||||
|
self._check_sharding_annotations(
|
||||||
|
jax_func, [x, y],
|
||||||
|
expected=[
|
||||||
|
r"f32\[24,8\].*sharding={devices=\[4,1\]0,1,2,3", # x
|
||||||
|
r"f32\[8,2\].*sharding={devices=\[4,1\]0,1,2,3", # y
|
||||||
|
r"f32\[24,2\].*sharding={devices=\[4,1\]0,1,2,3", # output
|
||||||
|
],
|
||||||
|
expected_opt=[
|
||||||
|
# TODO: relax ordering
|
||||||
|
r"f32\[2,2\].*sharding={devices=\[4,1\]0,1,2,3", # y
|
||||||
|
r"f32\[6,8\].*sharding={devices=\[4,1\]0,1,2,3", # x
|
||||||
|
# TODO: why we cannot see .*sharding={devices=\[4,1\]0,1,2,3
|
||||||
|
r"f32\[1,6,2\]", # output
|
||||||
|
],
|
||||||
|
num_partitions=4)
|
||||||
|
|
||||||
|
@jtu.with_mesh([("x", 2), ("y", 1)])
|
||||||
|
def test_pjit_ShardingConstraint(self):
|
||||||
|
@functools.partial(pjit.pjit, in_axis_resources=None,
|
||||||
|
out_axis_resources=None)
|
||||||
|
def jax_func(x): # x: f32[12, 8]
|
||||||
|
y = jnp.tile(x, (2, 1)) # y: f32[24, 8]
|
||||||
|
y = pjit.with_sharding_constraint(y, P("x", "y"))
|
||||||
|
return y[0:y.shape[0] // 4] # res: f32[6, 8]
|
||||||
|
|
||||||
|
shape = (12, 8)
|
||||||
|
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
|
||||||
|
self._check_sharding_annotations(
|
||||||
|
jax_func, [x],
|
||||||
|
expected=[
|
||||||
|
r"f32\[12,8\].*sharding={replicated}", # x
|
||||||
|
r"f32\[24,8\].*sharding={devices=\[2,1\]0,1", # y
|
||||||
|
r"f32\[6,8\].*sharding={replicated}", # output
|
||||||
|
],
|
||||||
|
expected_opt=[
|
||||||
|
r"f32\[12,8\].*sharding={replicated}", # x
|
||||||
|
# TODO: why can't we see "sharding={devices=\[2,1\]0,1"
|
||||||
|
r"f32\[1,12,8\]", # y
|
||||||
|
# TODO: why can't we see "sharding={replicated}" ?
|
||||||
|
r"f32\[6,8\]", # output
|
||||||
|
],
|
||||||
|
num_partitions=2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||||
|
@ -590,13 +590,20 @@ def get_array_mapping(axis_resources: ParsedPartitionSpec) -> pxla.ArrayMapping:
|
|||||||
for i, axes in enumerate(axis_resources)
|
for i, axes in enumerate(axis_resources)
|
||||||
for axis in axes)
|
for axis in axes)
|
||||||
|
|
||||||
def get_sharding_proto(c, xla_op, axis_resources, mesh):
|
def get_sharding_proto(c, xla_op, axis_resources: ParsedPartitionSpec,
|
||||||
|
mesh: maps.Mesh) -> xc.OpSharding:
|
||||||
xla_shape = c.GetShape(xla_op)
|
xla_shape = c.GetShape(xla_op)
|
||||||
if xla_shape.is_token():
|
if xla_shape.is_token():
|
||||||
aval = core.abstract_token
|
aval = core.abstract_token
|
||||||
assert axis_resources is REPLICATED
|
assert axis_resources is REPLICATED
|
||||||
else:
|
else:
|
||||||
aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.element_type())
|
aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.element_type())
|
||||||
|
return get_aval_sharding_proto(aval, axis_resources, mesh)
|
||||||
|
|
||||||
|
|
||||||
|
def get_aval_sharding_proto(aval: core.AbstractValue,
|
||||||
|
axis_resources: ParsedPartitionSpec,
|
||||||
|
mesh: maps.Mesh) -> xc.OpSharding:
|
||||||
array_mapping = get_array_mapping(axis_resources)
|
array_mapping = get_array_mapping(axis_resources)
|
||||||
sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(
|
sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(
|
||||||
aval, array_mapping)
|
aval, array_mapping)
|
||||||
|
@ -17,7 +17,7 @@ import functools
|
|||||||
import re
|
import re
|
||||||
import os
|
import os
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import Dict, Sequence, Union
|
from typing import Dict, List, Generator, Sequence, Tuple, Union
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
import zlib
|
import zlib
|
||||||
@ -33,10 +33,12 @@ from . import core
|
|||||||
from ._src import dtypes as _dtypes
|
from ._src import dtypes as _dtypes
|
||||||
from . import lax
|
from . import lax
|
||||||
from ._src.config import flags, bool_env, config
|
from ._src.config import flags, bool_env, config
|
||||||
from ._src.util import partial, prod
|
from ._src.util import partial, prod, unzip2
|
||||||
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
|
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
|
||||||
from .lib import xla_bridge
|
from .lib import xla_bridge
|
||||||
from .interpreters import xla
|
from .interpreters import xla
|
||||||
|
from .experimental import maps
|
||||||
|
from .experimental.maps import mesh
|
||||||
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
@ -1012,6 +1014,44 @@ def ignore_warning(**kw):
|
|||||||
warnings.filterwarnings("ignore", **kw)
|
warnings.filterwarnings("ignore", **kw)
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
# -------------------- Mesh parametrization helpers --------------------
|
||||||
|
|
||||||
|
MeshSpec = List[Tuple[str, int]]
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
|
||||||
|
"""Test utility for setting up meshes given mesh data from `schedules`."""
|
||||||
|
# This is similar to the `with_mesh` function above, but isn't a decorator.
|
||||||
|
axis_names, shape = unzip2(named_shape)
|
||||||
|
size = prod(shape)
|
||||||
|
local_devices = list(api.local_devices())
|
||||||
|
if len(local_devices) < size:
|
||||||
|
raise unittest.SkipTest(f"Test requires {size} local devices")
|
||||||
|
mesh_devices = np.array(local_devices[:size]).reshape(shape)
|
||||||
|
with mesh(mesh_devices, axis_names):
|
||||||
|
yield
|
||||||
|
|
||||||
|
def with_mesh_from_kwargs(f):
|
||||||
|
return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
|
||||||
|
|
||||||
|
def with_and_without_mesh(f):
|
||||||
|
return parameterized.named_parameters(
|
||||||
|
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
|
||||||
|
for name, mesh, axis_resources in (
|
||||||
|
('', (), ()),
|
||||||
|
('Mesh', (('x', 2),), (('i', 'x'),))
|
||||||
|
))(with_mesh_from_kwargs(f))
|
||||||
|
|
||||||
|
old_spmd_lowering_flag = False
|
||||||
|
def set_spmd_lowering_flag(val: bool):
|
||||||
|
global old_spmd_lowering_flag
|
||||||
|
maps.make_xmap_callable.cache_clear()
|
||||||
|
old_spmd_lowering_flag = maps.EXPERIMENTAL_SPMD_LOWERING
|
||||||
|
maps.EXPERIMENTAL_SPMD_LOWERING = val
|
||||||
|
|
||||||
|
def restore_spmd_lowering_flag():
|
||||||
|
maps.make_xmap_callable.cache_clear()
|
||||||
|
maps.EXPERIMENTAL_SPMD_LOWERING = old_spmd_lowering_flag
|
||||||
|
|
||||||
class _cached_property:
|
class _cached_property:
|
||||||
null = object()
|
null = object()
|
||||||
|
@ -12,10 +12,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import Generator, List, Tuple
|
|
||||||
from unittest import SkipTest
|
from unittest import SkipTest
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
@ -32,47 +30,23 @@ from jax.experimental import PartitionSpec as P
|
|||||||
from jax.experimental.maps import xmap, mesh
|
from jax.experimental.maps import xmap, mesh
|
||||||
from jax.experimental.pjit import pjit, pjit_p, with_sharding_constraint, SpecSync
|
from jax.experimental.pjit import pjit, pjit_p, with_sharding_constraint, SpecSync
|
||||||
from jax.interpreters import pxla
|
from jax.interpreters import pxla
|
||||||
from jax._src.util import unzip2, prod, curry
|
from jax._src.util import prod, curry
|
||||||
|
|
||||||
from jax.config import config
|
from jax.config import config
|
||||||
config.parse_flags_with_absl()
|
config.parse_flags_with_absl()
|
||||||
|
|
||||||
|
|
||||||
def setUpModule():
|
def setUpModule():
|
||||||
global old_lowering_flag
|
jtu.set_spmd_lowering_flag(True)
|
||||||
jax.experimental.maps.make_xmap_callable.cache_clear()
|
|
||||||
old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
|
|
||||||
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
|
|
||||||
|
|
||||||
def tearDownModule():
|
def tearDownModule():
|
||||||
jax.experimental.maps.make_xmap_callable.cache_clear()
|
jtu.restore_spmd_lowering_flag()
|
||||||
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = old_lowering_flag
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(skye): move into test_util and dedup with xmap_test.py
|
|
||||||
MeshSpec = List[Tuple[str, int]]
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
|
|
||||||
"""Test utility for setting up meshes given mesh data from `schedules`."""
|
|
||||||
# This is similar to the `with_mesh` function above, but isn't a decorator.
|
|
||||||
axis_names, shape = unzip2(named_shape)
|
|
||||||
size = prod(shape)
|
|
||||||
local_devices = list(jax.local_devices())
|
|
||||||
if len(local_devices) < size:
|
|
||||||
raise SkipTest(f"Test requires {size} local devices")
|
|
||||||
mesh_devices = np.array(local_devices[:size]).reshape(shape)
|
|
||||||
with mesh(mesh_devices, axis_names):
|
|
||||||
yield
|
|
||||||
|
|
||||||
def with_mesh_from_kwargs(f):
|
|
||||||
return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(skye): make the buffer donation utils part of JaxTestCase
|
# TODO(skye): make the buffer donation utils part of JaxTestCase
|
||||||
class PJitTest(jtu.BufferDonationTestCase):
|
class PJitTest(jtu.BufferDonationTestCase):
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testBasic1D(self):
|
def testBasic1D(self):
|
||||||
@partial(pjit,
|
@partial(pjit,
|
||||||
in_axis_resources=(P('x'), P('x')),
|
in_axis_resources=(P('x'), P('x')),
|
||||||
@ -90,7 +64,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
self.assertAllClose(actual.device_buffers[0].to_py(), expected,
|
self.assertAllClose(actual.device_buffers[0].to_py(), expected,
|
||||||
check_dtypes=False)
|
check_dtypes=False)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 2)])
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||||
def testBasic2D(self):
|
def testBasic2D(self):
|
||||||
@partial(pjit,
|
@partial(pjit,
|
||||||
in_axis_resources=(P(None, 'x', 'y'), P('y')),
|
in_axis_resources=(P(None, 'x', 'y'), P('y')),
|
||||||
@ -118,7 +92,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
self.assertAllClose(actual.device_buffers[3].to_py(), split1,
|
self.assertAllClose(actual.device_buffers[3].to_py(), split1,
|
||||||
check_dtypes=False)
|
check_dtypes=False)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 2)])
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||||
def testTwoMeshAxisSharding(self):
|
def testTwoMeshAxisSharding(self):
|
||||||
@partial(pjit,
|
@partial(pjit,
|
||||||
in_axis_resources=P(('x', 'y'),),
|
in_axis_resources=P(('x', 'y'),),
|
||||||
@ -144,7 +118,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
self.assertAllClose(actual.device_buffers[3].to_py(), splits[3],
|
self.assertAllClose(actual.device_buffers[3].to_py(), splits[3],
|
||||||
check_dtypes=False)
|
check_dtypes=False)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testBufferDonation(self):
|
def testBufferDonation(self):
|
||||||
@partial(pjit,
|
@partial(pjit,
|
||||||
in_axis_resources=P('x'),
|
in_axis_resources=P('x'),
|
||||||
@ -162,7 +136,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
self.assertNotDeleted(y)
|
self.assertNotDeleted(y)
|
||||||
self.assertDeleted(x)
|
self.assertDeleted(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 1)])
|
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||||
def testShardingConstraint(self):
|
def testShardingConstraint(self):
|
||||||
@partial(pjit, in_axis_resources=None, out_axis_resources=None)
|
@partial(pjit, in_axis_resources=None, out_axis_resources=None)
|
||||||
def f(x):
|
def f(x):
|
||||||
@ -186,7 +160,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
# Annotation from pjit
|
# Annotation from pjit
|
||||||
self.assertIn("sharding={replicated}", hlo.as_hlo_text())
|
self.assertIn("sharding={replicated}", hlo.as_hlo_text())
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 1)])
|
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||||
def testShardingConstraintPyTree(self):
|
def testShardingConstraintPyTree(self):
|
||||||
@partial(pjit, in_axis_resources=None, out_axis_resources=None)
|
@partial(pjit, in_axis_resources=None, out_axis_resources=None)
|
||||||
def f(x):
|
def f(x):
|
||||||
@ -232,7 +206,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
should_be_tracing = False
|
should_be_tracing = False
|
||||||
pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
|
pjit(f, in_axis_resources=P(('x', 'y')), out_axis_resources=None)(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 1)])
|
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||||
def testNested(self):
|
def testNested(self):
|
||||||
# Add a constant captured by the nested pjit to make things more complicated
|
# Add a constant captured by the nested pjit to make things more complicated
|
||||||
h = jnp.arange(4)
|
h = jnp.arange(4)
|
||||||
@ -243,7 +217,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
self.assertAllClose(y, jnp.sin(x).sum() + h.sum())
|
self.assertAllClose(y, jnp.sin(x).sum() + h.sum())
|
||||||
self.assertTrue(hasattr(y, "sharding_spec"))
|
self.assertTrue(hasattr(y, "sharding_spec"))
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 1)])
|
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||||
def testJVP(self):
|
def testJVP(self):
|
||||||
# Add a constant captured by the nested pjit to make things more complicated
|
# Add a constant captured by the nested pjit to make things more complicated
|
||||||
h = jnp.arange(4)
|
h = jnp.arange(4)
|
||||||
@ -252,7 +226,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
jtu.check_grads(g, (jnp.arange(16, dtype=jnp.float32).reshape((4, 4)),),
|
jtu.check_grads(g, (jnp.arange(16, dtype=jnp.float32).reshape((4, 4)),),
|
||||||
order=2, modes=["fwd"], eps=1)
|
order=2, modes=["fwd"], eps=1)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 1)])
|
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||||
def testEvalJaxpr(self):
|
def testEvalJaxpr(self):
|
||||||
x, y = jnp.arange(4), jnp.arange(5)
|
x, y = jnp.arange(4), jnp.arange(5)
|
||||||
f = pjit(lambda x, y: x.sum() + jnp.sin(y),
|
f = pjit(lambda x, y: x.sum() + jnp.sin(y),
|
||||||
@ -263,13 +237,13 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
r, = f_eval(x, y)
|
r, = f_eval(x, y)
|
||||||
self.assertAllClose(r, x.sum() + jnp.sin(y))
|
self.assertAllClose(r, x.sum() + jnp.sin(y))
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testNonArrayArg(self):
|
def testNonArrayArg(self):
|
||||||
self.assertEqual(pjit(lambda x: x + 2,
|
self.assertEqual(pjit(lambda x: x + 2,
|
||||||
in_axis_resources=None,
|
in_axis_resources=None,
|
||||||
out_axis_resources=None)(1), 3)
|
out_axis_resources=None)(1), 3)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testNonHashableAxisResources(self):
|
def testNonHashableAxisResources(self):
|
||||||
x = jnp.arange(4)
|
x = jnp.arange(4)
|
||||||
y = pjit(lambda x: {'b': x['a'] + 2},
|
y = pjit(lambda x: {'b': x['a'] + 2},
|
||||||
@ -277,7 +251,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
out_axis_resources={'b': P('x')})({'a': x})
|
out_axis_resources={'b': P('x')})({'a': x})
|
||||||
self.assertAllClose(y, {'b': x + 2})
|
self.assertAllClose(y, {'b': x + 2})
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testGradOfConstraint(self):
|
def testGradOfConstraint(self):
|
||||||
# Make sure that we can compute grads through sharding constraints
|
# Make sure that we can compute grads through sharding constraints
|
||||||
h = lambda x: jnp.sin(with_sharding_constraint(x, P('x'))).sum()
|
h = lambda x: jnp.sin(with_sharding_constraint(x, P('x'))).sum()
|
||||||
@ -286,7 +260,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
x = jnp.arange(8, dtype=jnp.float32)
|
x = jnp.arange(8, dtype=jnp.float32)
|
||||||
self.assertAllClose(f(x), jnp.cos(x))
|
self.assertAllClose(f(x), jnp.cos(x))
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testNoopPartitionSpecs(self):
|
def testNoopPartitionSpecs(self):
|
||||||
noops = [P(), P(None), P(()), P((), None), P(None, None, ())]
|
noops = [P(), P(None), P(()), P((), None), P(None, None, ())]
|
||||||
x = jnp.arange(8).reshape((2, 2, 2))
|
x = jnp.arange(8).reshape((2, 2, 2))
|
||||||
@ -294,7 +268,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
y = pjit(lambda x: x * 2, in_axis_resources=spec, out_axis_resources=spec)(x)
|
y = pjit(lambda x: x * 2, in_axis_resources=spec, out_axis_resources=spec)(x)
|
||||||
self.assertAllClose(y, x * 2)
|
self.assertAllClose(y, x * 2)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testVmapModifiesAxisResources(self):
|
def testVmapModifiesAxisResources(self):
|
||||||
h = pjit(lambda x, y: (x + y, x, y), in_axis_resources=P('x'), out_axis_resources=None)
|
h = pjit(lambda x, y: (x + y, x, y), in_axis_resources=P('x'), out_axis_resources=None)
|
||||||
x = jnp.arange(4)
|
x = jnp.arange(4)
|
||||||
@ -310,7 +284,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
self.assertEqual(y_sync, SpecSync.IN_SYNC)
|
self.assertEqual(y_sync, SpecSync.IN_SYNC)
|
||||||
self.assertEqual(z_sync, SpecSync.DIM_PERMUTE)
|
self.assertEqual(z_sync, SpecSync.DIM_PERMUTE)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testVMap(self):
|
def testVMap(self):
|
||||||
f = pjit(lambda x, y: (x + y, x), in_axis_resources=P('x'), out_axis_resources=P('x'))
|
f = pjit(lambda x, y: (x + y, x), in_axis_resources=P('x'), out_axis_resources=P('x'))
|
||||||
x = jnp.arange(4)
|
x = jnp.arange(4)
|
||||||
@ -402,7 +376,7 @@ def check_1d_2d_mesh(f, set_mesh):
|
|||||||
("2", (("x", 2),), "x"),
|
("2", (("x", 2),), "x"),
|
||||||
("2x1", (("x", 2), ("y", 1)), ("x", "y")),
|
("2x1", (("x", 2), ("y", 1)), ("x", "y")),
|
||||||
("2x2", (("x", 2), ("y", 2)), ("x", "y")),
|
("2x2", (("x", 2), ("y", 2)), ("x", "y")),
|
||||||
))(with_mesh_from_kwargs(f) if set_mesh else f)
|
))(jtu.with_mesh_from_kwargs(f) if set_mesh else f)
|
||||||
|
|
||||||
def spec_regex(s):
|
def spec_regex(s):
|
||||||
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
|
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
|
||||||
@ -444,7 +418,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
|||||||
in_axis_resources=None, out_axis_resources=None)(x)
|
in_axis_resources=None, out_axis_resources=None)(x)
|
||||||
|
|
||||||
@check_1d_2d_mesh(set_mesh=False)
|
@check_1d_2d_mesh(set_mesh=False)
|
||||||
@with_mesh([('z', 1)])
|
@jtu.with_mesh([('z', 1)])
|
||||||
def testUndefinedResourcesArgs(self, mesh, resources):
|
def testUndefinedResourcesArgs(self, mesh, resources):
|
||||||
x = jnp.ones((2, 2))
|
x = jnp.ones((2, 2))
|
||||||
spec = P(resources,)
|
spec = P(resources,)
|
||||||
@ -454,7 +428,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
|||||||
pjit(lambda x: x, in_axis_resources=spec, out_axis_resources=None)(x)
|
pjit(lambda x: x, in_axis_resources=spec, out_axis_resources=None)(x)
|
||||||
|
|
||||||
@check_1d_2d_mesh(set_mesh=False)
|
@check_1d_2d_mesh(set_mesh=False)
|
||||||
@with_mesh([('z', 1)])
|
@jtu.with_mesh([('z', 1)])
|
||||||
def testUndefinedResourcesOuts(self, mesh, resources):
|
def testUndefinedResourcesOuts(self, mesh, resources):
|
||||||
x = jnp.ones((2, 2))
|
x = jnp.ones((2, 2))
|
||||||
spec = P(resources,)
|
spec = P(resources,)
|
||||||
@ -464,7 +438,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
|||||||
pjit(lambda x: x, in_axis_resources=None, out_axis_resources=spec)(x)
|
pjit(lambda x: x, in_axis_resources=None, out_axis_resources=spec)(x)
|
||||||
|
|
||||||
@check_1d_2d_mesh(set_mesh=False)
|
@check_1d_2d_mesh(set_mesh=False)
|
||||||
@with_mesh([('z', 1)])
|
@jtu.with_mesh([('z', 1)])
|
||||||
def testUndefinedResourcesConstraint(self, mesh, resources):
|
def testUndefinedResourcesConstraint(self, mesh, resources):
|
||||||
x = jnp.ones((2, 2))
|
x = jnp.ones((2, 2))
|
||||||
spec = P(resources,)
|
spec = P(resources,)
|
||||||
@ -475,7 +449,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
|||||||
pjit(lambda x: with_sharding_constraint(x, spec),
|
pjit(lambda x: with_sharding_constraint(x, spec),
|
||||||
in_axis_resources=None, out_axis_resources=None)(x)
|
in_axis_resources=None, out_axis_resources=None)(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 1)])
|
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||||
def testRankTooLowArgs(self):
|
def testRankTooLowArgs(self):
|
||||||
x = jnp.arange(2)
|
x = jnp.arange(2)
|
||||||
spec = P('x', 'y')
|
spec = P('x', 'y')
|
||||||
@ -484,7 +458,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(ValueError, error):
|
with self.assertRaisesRegex(ValueError, error):
|
||||||
pjit(lambda x: x.sum(), in_axis_resources=spec, out_axis_resources=None)(x)
|
pjit(lambda x: x.sum(), in_axis_resources=spec, out_axis_resources=None)(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 1)])
|
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||||
def testRankTooLowOuts(self):
|
def testRankTooLowOuts(self):
|
||||||
x = jnp.arange(2)
|
x = jnp.arange(2)
|
||||||
spec = P('x', 'y')
|
spec = P('x', 'y')
|
||||||
@ -493,7 +467,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(ValueError, error):
|
with self.assertRaisesRegex(ValueError, error):
|
||||||
pjit(lambda x: x.sum(), in_axis_resources=None, out_axis_resources=spec)(x)
|
pjit(lambda x: x.sum(), in_axis_resources=None, out_axis_resources=spec)(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 1)])
|
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||||
def testRankTooLowConstraint(self):
|
def testRankTooLowConstraint(self):
|
||||||
x = jnp.arange(2)
|
x = jnp.arange(2)
|
||||||
spec = P('x', 'y')
|
spec = P('x', 'y')
|
||||||
@ -504,7 +478,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
|||||||
pjit(lambda x: with_sharding_constraint(x, spec),
|
pjit(lambda x: with_sharding_constraint(x, spec),
|
||||||
in_axis_resources=None, out_axis_resources=None)(x)
|
in_axis_resources=None, out_axis_resources=None)(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 1)])
|
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||||
def testRepeatedInResources(self):
|
def testRepeatedInResources(self):
|
||||||
x = jnp.arange(2)
|
x = jnp.arange(2)
|
||||||
for spec in [P('x', 'x'), P('x', ('y', 'x'))]:
|
for spec in [P('x', 'x'), P('x', ('y', 'x'))]:
|
||||||
@ -514,7 +488,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(ValueError, error):
|
with self.assertRaisesRegex(ValueError, error):
|
||||||
pjit(lambda x: x, in_axis_resources=spec, out_axis_resources=None)(x)
|
pjit(lambda x: x, in_axis_resources=spec, out_axis_resources=None)(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 1)])
|
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||||
def testRepeatedOutResources(self):
|
def testRepeatedOutResources(self):
|
||||||
x = jnp.arange(2)
|
x = jnp.arange(2)
|
||||||
for spec in [P('x', 'x'), P('x', ('y', 'x'))]:
|
for spec in [P('x', 'x'), P('x', ('y', 'x'))]:
|
||||||
@ -524,7 +498,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(ValueError, error):
|
with self.assertRaisesRegex(ValueError, error):
|
||||||
pjit(lambda x: x, in_axis_resources=None, out_axis_resources=spec)(x)
|
pjit(lambda x: x, in_axis_resources=None, out_axis_resources=spec)(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testInputShardsXMapAxis(self):
|
def testInputShardsXMapAxis(self):
|
||||||
spec = P('x')
|
spec = P('x')
|
||||||
f = xmap(pjit(lambda x: x + 2, in_axis_resources=spec, out_axis_resources=None),
|
f = xmap(pjit(lambda x: x + 2, in_axis_resources=spec, out_axis_resources=None),
|
||||||
@ -537,7 +511,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(JAXTypeError, error):
|
with self.assertRaisesRegex(JAXTypeError, error):
|
||||||
f(x)
|
f(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testOutputShardsXMapAxis(self):
|
def testOutputShardsXMapAxis(self):
|
||||||
spec = P('x')
|
spec = P('x')
|
||||||
f = xmap(pjit(lambda x: x + 2, in_axis_resources=None, out_axis_resources=spec),
|
f = xmap(pjit(lambda x: x + 2, in_axis_resources=None, out_axis_resources=spec),
|
||||||
@ -550,7 +524,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(JAXTypeError, error):
|
with self.assertRaisesRegex(JAXTypeError, error):
|
||||||
f(x)
|
f(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testConstraintShardsXMapAxis(self):
|
def testConstraintShardsXMapAxis(self):
|
||||||
spec = P('x')
|
spec = P('x')
|
||||||
f = xmap(lambda x: with_sharding_constraint(x, axis_resources=spec),
|
f = xmap(lambda x: with_sharding_constraint(x, axis_resources=spec),
|
||||||
@ -563,7 +537,7 @@ class PJitErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(JAXTypeError, error):
|
with self.assertRaisesRegex(JAXTypeError, error):
|
||||||
f(x)
|
f(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testCatchesInnerXMapErrors(self):
|
def testCatchesInnerXMapErrors(self):
|
||||||
f = pjit(xmap(lambda x, y: x, in_axes=(['i'], ['j']), out_axes=['i', 'j'],
|
f = pjit(xmap(lambda x, y: x, in_axes=(['i'], ['j']), out_axes=['i', 'j'],
|
||||||
axis_resources={'i': 'x', 'j': 'x'}),
|
axis_resources={'i': 'x', 'j': 'x'}),
|
||||||
|
@ -71,34 +71,6 @@ def tearDownModule():
|
|||||||
os.environ["XLA_FLAGS"] = prev_xla_flags
|
os.environ["XLA_FLAGS"] = prev_xla_flags
|
||||||
xla_bridge.get_backend.cache_clear()
|
xla_bridge.get_backend.cache_clear()
|
||||||
|
|
||||||
# -------------------- Mesh parametrization helpers --------------------
|
|
||||||
|
|
||||||
MeshSpec = List[Tuple[str, int]]
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
|
|
||||||
"""Test utility for setting up meshes given mesh data from `schedules`."""
|
|
||||||
# This is similar to the `with_mesh` function above, but isn't a decorator.
|
|
||||||
axis_names, shape = unzip2(named_shape)
|
|
||||||
size = prod(shape)
|
|
||||||
local_devices = list(jax.local_devices())
|
|
||||||
if len(local_devices) < size:
|
|
||||||
raise SkipTest(f"Test requires {size} local devices")
|
|
||||||
mesh_devices = np.array(local_devices[:size]).reshape(shape)
|
|
||||||
with mesh(mesh_devices, axis_names):
|
|
||||||
yield
|
|
||||||
|
|
||||||
def with_mesh_from_kwargs(f):
|
|
||||||
return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
|
|
||||||
|
|
||||||
def with_and_without_mesh(f):
|
|
||||||
return parameterized.named_parameters(
|
|
||||||
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
|
|
||||||
for name, mesh, axis_resources in (
|
|
||||||
('', (), ()),
|
|
||||||
('Mesh', (('x', 2),), (('i', 'x'),))
|
|
||||||
))(with_mesh_from_kwargs(f))
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------- Itertools helpers --------------------
|
# -------------------- Itertools helpers --------------------
|
||||||
|
|
||||||
@ -135,7 +107,7 @@ def ensure_bdim(x, axis_name, bdim):
|
|||||||
AxisResources = Dict[str, Union[str, Tuple[str, ...]]]
|
AxisResources = Dict[str, Union[str, Tuple[str, ...]]]
|
||||||
|
|
||||||
def schedules(sizes: Dict[str, int]
|
def schedules(sizes: Dict[str, int]
|
||||||
) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
|
) -> Generator[Tuple[AxisResources, jtu.MeshSpec], None, None]:
|
||||||
"""Test utility generating xmap parallel schedules from logical names & sizes.
|
"""Test utility generating xmap parallel schedules from logical names & sizes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -275,7 +247,7 @@ class XMapTest(XMapTestCase):
|
|||||||
self.assertAllClose(c, a * 2)
|
self.assertAllClose(c, a * 2)
|
||||||
self.assertAllClose(d, b * 4)
|
self.assertAllClose(d, b * 4)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 2)])
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||||
def testCollectiveReduce(self):
|
def testCollectiveReduce(self):
|
||||||
fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
|
fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
|
||||||
in_axes=[['a', 'b', ...], {0: 'c'}],
|
in_axes=[['a', 'b', ...], {0: 'c'}],
|
||||||
@ -289,7 +261,7 @@ class XMapTest(XMapTestCase):
|
|||||||
self.assertAllClose(c, (a * 2).sum(0))
|
self.assertAllClose(c, (a * 2).sum(0))
|
||||||
self.assertAllClose(d, b * 4)
|
self.assertAllClose(d, b * 4)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 2)])
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||||
def testCollectivePermute2D(self):
|
def testCollectivePermute2D(self):
|
||||||
perm = np.array([3, 1, 2, 0])
|
perm = np.array([3, 1, 2, 0])
|
||||||
x = jnp.arange(4).reshape((2, 2))
|
x = jnp.arange(4).reshape((2, 2))
|
||||||
@ -313,7 +285,7 @@ class XMapTest(XMapTestCase):
|
|||||||
in_axes=['i', ...], out_axes=['i', ...])(x)
|
in_axes=['i', ...], out_axes=['i', ...])(x)
|
||||||
self.assertAllClose(result, x + x[jnp.newaxis].T)
|
self.assertAllClose(result, x + x[jnp.newaxis].T)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 2)])
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||||
def testOneLogicalTwoMeshAxesBasic(self):
|
def testOneLogicalTwoMeshAxesBasic(self):
|
||||||
def f(v):
|
def f(v):
|
||||||
return lax.psum(v * 2, 'a'), v * 4
|
return lax.psum(v * 2, 'a'), v * 4
|
||||||
@ -325,7 +297,7 @@ class XMapTest(XMapTestCase):
|
|||||||
self.assertAllClose(ans, (v * 2).sum(0))
|
self.assertAllClose(ans, (v * 2).sum(0))
|
||||||
self.assertAllClose(ans2, v.T * 4)
|
self.assertAllClose(ans2, v.T * 4)
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 2)])
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||||
def testOneLogicalTwoMeshAxesSharding(self):
|
def testOneLogicalTwoMeshAxesSharding(self):
|
||||||
def f(v):
|
def f(v):
|
||||||
return v * 4
|
return v * 4
|
||||||
@ -346,7 +318,7 @@ class XMapTest(XMapTestCase):
|
|||||||
pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
|
pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
|
||||||
(pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
|
(pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 2)])
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||||
def testSkipFirstMeshDim(self):
|
def testSkipFirstMeshDim(self):
|
||||||
def run(axis_resources):
|
def run(axis_resources):
|
||||||
return xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
|
return xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
|
||||||
@ -379,7 +351,7 @@ class XMapTest(XMapTestCase):
|
|||||||
('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))),
|
('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))),
|
||||||
('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))),
|
('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))),
|
||||||
))
|
))
|
||||||
@with_mesh_from_kwargs
|
@jtu.with_mesh_from_kwargs
|
||||||
def testNestedMesh(self, mesh, axis_resources):
|
def testNestedMesh(self, mesh, axis_resources):
|
||||||
@partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
|
@partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
|
||||||
axis_resources=dict([axis_resources[0]]))
|
axis_resources=dict([axis_resources[0]]))
|
||||||
@ -405,7 +377,7 @@ class XMapTest(XMapTestCase):
|
|||||||
# Make sure that there are non-partial sharding specs in the HLO
|
# Make sure that there are non-partial sharding specs in the HLO
|
||||||
self.assertRegex(hlo, r"sharding={devices=\[[0-9,]+\][0-9,]+}")
|
self.assertRegex(hlo, r"sharding={devices=\[[0-9,]+\][0-9,]+}")
|
||||||
|
|
||||||
@with_and_without_mesh
|
@jtu.with_and_without_mesh
|
||||||
def testMultipleCalls(self, mesh, axis_resources):
|
def testMultipleCalls(self, mesh, axis_resources):
|
||||||
def f(x, y):
|
def f(x, y):
|
||||||
assert x.shape == y.shape == (3, 5)
|
assert x.shape == y.shape == (3, 5)
|
||||||
@ -420,7 +392,7 @@ class XMapTest(XMapTestCase):
|
|||||||
for i in range(10):
|
for i in range(10):
|
||||||
self.assertAllClose(f_mapped(x, x), expected)
|
self.assertAllClose(f_mapped(x, x), expected)
|
||||||
|
|
||||||
@with_and_without_mesh
|
@jtu.with_and_without_mesh
|
||||||
@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
|
@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
|
||||||
def testBufferDonation(self, mesh, axis_resources):
|
def testBufferDonation(self, mesh, axis_resources):
|
||||||
shard = lambda x: x
|
shard = lambda x: x
|
||||||
@ -443,7 +415,7 @@ class XMapTest(XMapTestCase):
|
|||||||
xmap(lambda x: lax.fori_loop(0, 10, lambda _, x: lax.psum(x, 'i'), x),
|
xmap(lambda x: lax.fori_loop(0, 10, lambda _, x: lax.psum(x, 'i'), x),
|
||||||
in_axes=['i', ...], out_axes=['i', ...])(x)
|
in_axes=['i', ...], out_axes=['i', ...])(x)
|
||||||
|
|
||||||
@with_and_without_mesh
|
@jtu.with_and_without_mesh
|
||||||
def testAxisSizes(self, mesh, axis_resources):
|
def testAxisSizes(self, mesh, axis_resources):
|
||||||
result = xmap(lambda: lax.axis_index('i'),
|
result = xmap(lambda: lax.axis_index('i'),
|
||||||
in_axes=(), out_axes=['i', ...],
|
in_axes=(), out_axes=['i', ...],
|
||||||
@ -551,7 +523,7 @@ class XMapTest(XMapTestCase):
|
|||||||
y = jnp.arange(20, dtype=jnp.float32).reshape((4, 5)) / 100
|
y = jnp.arange(20, dtype=jnp.float32).reshape((4, 5)) / 100
|
||||||
jtu.check_grads(f, (x, y), order=2, modes=['fwd'])
|
jtu.check_grads(f, (x, y), order=2, modes=['fwd'])
|
||||||
|
|
||||||
@with_and_without_mesh
|
@jtu.with_and_without_mesh
|
||||||
def testNamedShape(self, mesh, axis_resources):
|
def testNamedShape(self, mesh, axis_resources):
|
||||||
x = np.arange(4,)
|
x = np.arange(4,)
|
||||||
y = 2
|
y = 2
|
||||||
@ -563,7 +535,7 @@ class XMapTest(XMapTestCase):
|
|||||||
self.assertEqual(z.aval.named_shape, {})
|
self.assertEqual(z.aval.named_shape, {})
|
||||||
self.assertEqual(w.aval.named_shape, {})
|
self.assertEqual(w.aval.named_shape, {})
|
||||||
|
|
||||||
@with_and_without_mesh
|
@jtu.with_and_without_mesh
|
||||||
def testBroadcast(self, mesh, axis_resources):
|
def testBroadcast(self, mesh, axis_resources):
|
||||||
x = jnp.asarray(2.0)
|
x = jnp.asarray(2.0)
|
||||||
f = xmap(lambda x: x, in_axes={}, out_axes=['i'],
|
f = xmap(lambda x: x, in_axes={}, out_axes=['i'],
|
||||||
@ -590,7 +562,7 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest):
|
|||||||
raise SkipTest
|
raise SkipTest
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 2), ('z', 2)])
|
@jtu.with_mesh([('x', 2), ('y', 2), ('z', 2)])
|
||||||
def testNestedMeshSPMD(self):
|
def testNestedMeshSPMD(self):
|
||||||
h = xmap(lambda y: (jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b', 'c'))),
|
h = xmap(lambda y: (jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b', 'c'))),
|
||||||
in_axes={0: 'c'}, out_axes=({1: 'c'}, {}),
|
in_axes={0: 'c'}, out_axes=({1: 'c'}, {}),
|
||||||
@ -670,7 +642,7 @@ class NamedRandomTest(XMapTestCase):
|
|||||||
((f"_mesh={mesh}_resources={sorted(axis_resources.items())}",
|
((f"_mesh={mesh}_resources={sorted(axis_resources.items())}",
|
||||||
{"axis_resources": tuple(axis_resources.items()), "mesh": tuple(mesh)})
|
{"axis_resources": tuple(axis_resources.items()), "mesh": tuple(mesh)})
|
||||||
for axis_resources, mesh in schedules({'i': 4, 'j': 6})), subset=True)
|
for axis_resources, mesh in schedules({'i': 4, 'j': 6})), subset=True)
|
||||||
@with_mesh_from_kwargs
|
@jtu.with_mesh_from_kwargs
|
||||||
def testSamplerResourceIndependence(self, distr_sample, axis_resources, mesh):
|
def testSamplerResourceIndependence(self, distr_sample, axis_resources, mesh):
|
||||||
def sample(axis_resources):
|
def sample(axis_resources):
|
||||||
return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=NamedShape(3, i=4, j=6)),
|
return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=NamedShape(3, i=4, j=6)),
|
||||||
@ -743,7 +715,7 @@ class NewPrimitiveTest(XMapTestCase):
|
|||||||
x_explode = x.reshape((3, 3, 3))
|
x_explode = x.reshape((3, 3, 3))
|
||||||
self.assertAllClose(pgather(x, idx, 0), pgather(x_explode, idx, (0, 1)))
|
self.assertAllClose(pgather(x, idx, 0), pgather(x_explode, idx, (0, 1)))
|
||||||
|
|
||||||
@with_and_without_mesh
|
@jtu.with_and_without_mesh
|
||||||
def testGather(self, mesh, axis_resources):
|
def testGather(self, mesh, axis_resources):
|
||||||
if axis_resources and not jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING:
|
if axis_resources and not jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING:
|
||||||
raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
|
raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
|
||||||
@ -851,7 +823,7 @@ def gen_axis_names():
|
|||||||
|
|
||||||
def schedules_from_pdot_spec(
|
def schedules_from_pdot_spec(
|
||||||
spec: PdotTestSpec, lhs_shape: Tuple[int], rhs_shape: Tuple[int]
|
spec: PdotTestSpec, lhs_shape: Tuple[int], rhs_shape: Tuple[int]
|
||||||
) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
|
) -> Generator[Tuple[AxisResources, jtu.MeshSpec], None, None]:
|
||||||
logical_sizes = {
|
logical_sizes = {
|
||||||
name: shape[ax]
|
name: shape[ax]
|
||||||
for shape, in_axes in [(lhs_shape, spec.lhs_in_axes),
|
for shape, in_axes in [(lhs_shape, spec.lhs_in_axes),
|
||||||
@ -862,7 +834,7 @@ def schedules_from_pdot_spec(
|
|||||||
|
|
||||||
class PDotTests(XMapTestCase):
|
class PDotTests(XMapTestCase):
|
||||||
|
|
||||||
@with_mesh([('r1', 2)])
|
@jtu.with_mesh([('r1', 2)])
|
||||||
def testPdotBasic(self):
|
def testPdotBasic(self):
|
||||||
def f(x, y):
|
def f(x, y):
|
||||||
return lax.pdot(x, y, 'i')
|
return lax.pdot(x, y, 'i')
|
||||||
@ -880,7 +852,7 @@ class PDotTests(XMapTestCase):
|
|||||||
|
|
||||||
self.assertAllClose(z, jnp.dot(x, y))
|
self.assertAllClose(z, jnp.dot(x, y))
|
||||||
|
|
||||||
@with_mesh([('r1', 2)])
|
@jtu.with_mesh([('r1', 2)])
|
||||||
def testPdotBatching(self):
|
def testPdotBatching(self):
|
||||||
def f(x, y):
|
def f(x, y):
|
||||||
return lax.pdot(x, y, 'i')
|
return lax.pdot(x, y, 'i')
|
||||||
@ -898,7 +870,7 @@ class PDotTests(XMapTestCase):
|
|||||||
|
|
||||||
self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))
|
self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))
|
||||||
|
|
||||||
@with_mesh([('r1', 2)])
|
@jtu.with_mesh([('r1', 2)])
|
||||||
def testPdotBatchingShardUncontractedDim(self):
|
def testPdotBatchingShardUncontractedDim(self):
|
||||||
def f(x, y):
|
def f(x, y):
|
||||||
return lax.pdot(x, y, 'i')
|
return lax.pdot(x, y, 'i')
|
||||||
@ -945,7 +917,7 @@ class PDotTests(XMapTestCase):
|
|||||||
out_axes=[*pdot_spec.batch_names, ...],
|
out_axes=[*pdot_spec.batch_names, ...],
|
||||||
axis_resources=axis_resources)
|
axis_resources=axis_resources)
|
||||||
|
|
||||||
with with_mesh(mesh_data):
|
with jtu.with_mesh(mesh_data):
|
||||||
result = fun(lhs, rhs)
|
result = fun(lhs, rhs)
|
||||||
|
|
||||||
expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
|
expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
|
||||||
@ -989,7 +961,7 @@ class PDotTests(XMapTestCase):
|
|||||||
out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes),
|
out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes),
|
||||||
axis_resources=axis_resources)
|
axis_resources=axis_resources)
|
||||||
|
|
||||||
with with_mesh(mesh_data):
|
with jtu.with_mesh(mesh_data):
|
||||||
lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)
|
lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)
|
||||||
|
|
||||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||||
@ -1062,7 +1034,7 @@ class PDotTests(XMapTestCase):
|
|||||||
|
|
||||||
class XMapErrorTest(jtu.JaxTestCase):
|
class XMapErrorTest(jtu.JaxTestCase):
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testRepeatedAxisResource(self):
|
def testRepeatedAxisResource(self):
|
||||||
def f(v):
|
def f(v):
|
||||||
return v * 4
|
return v * 4
|
||||||
@ -1070,7 +1042,7 @@ class XMapErrorTest(jtu.JaxTestCase):
|
|||||||
fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
|
fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
|
||||||
axis_resources={'a': ('x', 'x')})
|
axis_resources={'a': ('x', 'x')})
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testNestedDifferentResources(self):
|
def testNestedDifferentResources(self):
|
||||||
@partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
|
@partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
|
||||||
def f(x):
|
def f(x):
|
||||||
@ -1089,7 +1061,7 @@ class XMapErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(ValueError, "Failed to infer size of axes: i."):
|
with self.assertRaisesRegex(ValueError, "Failed to infer size of axes: i."):
|
||||||
xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...])({})
|
xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...])({})
|
||||||
|
|
||||||
@with_mesh([('x', 2), ('y', 2)])
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||||
def testAxesNotDivisibleByResources(self):
|
def testAxesNotDivisibleByResources(self):
|
||||||
with self.assertRaisesRegex(ValueError, r"Size of axis i \(5\) is not divisible.*"
|
with self.assertRaisesRegex(ValueError, r"Size of axis i \(5\) is not divisible.*"
|
||||||
r"\(\('x', 'y'\), 4 in total\)"):
|
r"\(\('x', 'y'\), 4 in total\)"):
|
||||||
@ -1154,7 +1126,7 @@ class XMapErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(TypeError, error):
|
with self.assertRaisesRegex(TypeError, error):
|
||||||
fm(x, y)
|
fm(x, y)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testResourceConflictArgs(self):
|
def testResourceConflictArgs(self):
|
||||||
fm = xmap(lambda x: lax.psum(x, ('a', 'b')),
|
fm = xmap(lambda x: lax.psum(x, ('a', 'b')),
|
||||||
in_axes=['a', 'b'], out_axes=[],
|
in_axes=['a', 'b'], out_axes=[],
|
||||||
@ -1166,7 +1138,7 @@ class XMapErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(JAXTypeError, error):
|
with self.assertRaisesRegex(JAXTypeError, error):
|
||||||
fm(x)
|
fm(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testResourceConflictInner(self):
|
def testResourceConflictInner(self):
|
||||||
fm = xmap(lambda x, y: x + y,
|
fm = xmap(lambda x, y: x + y,
|
||||||
in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
|
in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
|
||||||
@ -1178,7 +1150,7 @@ class XMapErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(JAXTypeError, error):
|
with self.assertRaisesRegex(JAXTypeError, error):
|
||||||
fm(x, y)
|
fm(x, y)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testResourceConflictOut(self):
|
def testResourceConflictOut(self):
|
||||||
fm = xmap(lambda x, y: x,
|
fm = xmap(lambda x, y: x,
|
||||||
in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
|
in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
|
||||||
@ -1191,7 +1163,7 @@ class XMapErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(JAXTypeError, error):
|
with self.assertRaisesRegex(JAXTypeError, error):
|
||||||
fm(x, y)
|
fm(x, y)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testResourceConflictNestArgs(self):
|
def testResourceConflictNestArgs(self):
|
||||||
f = xmap(lambda x: x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'})
|
f = xmap(lambda x: x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'})
|
||||||
h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
|
h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
|
||||||
@ -1202,7 +1174,7 @@ class XMapErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(JAXTypeError, error):
|
with self.assertRaisesRegex(JAXTypeError, error):
|
||||||
h(x)
|
h(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testResourceConflictNestInner(self):
|
def testResourceConflictNestInner(self):
|
||||||
f = xmap(lambda x: lax.axis_index('i') + x,
|
f = xmap(lambda x: lax.axis_index('i') + x,
|
||||||
in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
|
in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
|
||||||
@ -1214,7 +1186,7 @@ class XMapErrorTest(jtu.JaxTestCase):
|
|||||||
with self.assertRaisesRegex(JAXTypeError, error):
|
with self.assertRaisesRegex(JAXTypeError, error):
|
||||||
h(x)
|
h(x)
|
||||||
|
|
||||||
@with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
def testResourceConflictNestOut(self):
|
def testResourceConflictNestOut(self):
|
||||||
f = xmap(lambda x: x,
|
f = xmap(lambda x: x,
|
||||||
in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
|
in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user