# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the jax2tf conversion of pjit.

To verify that the tests indeed run on multiple devices you can run

  perftools/gputools/profiler/jfprof.sh jax/experimental/jax2tf/tests:sharding_test_tpu -- -c opt --test_filter=ShardingTest.test_shmap_all_to_all --test_arg=--vmodule=jax2tf=3 --
"""
from collections.abc import Sequence
import contextlib
from functools import partial
import logging
import math
import os
import re
from typing import Any
import unittest

from absl.testing import absltest

import jax
from jax._src import compiler
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax import lax
from jax.experimental import jax2tf
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
import jax.numpy as jnp

import numpy as np

import tensorflow as tf

config.parse_flags_with_absl()

# Must come after initializing the flags
from jax.experimental.jax2tf.tests import tf_test_util

_exit_stack = contextlib.ExitStack()
topology = None

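# Note: jtu.set_host_platform_device_count(8) in setUpModule makes the CPU
# backend expose 8 devices (roughly, it sets the XLA flag that forces the host
# platform device count before the backend is initialized), so the sharding
# tests in this file can also run on a host without accelerators.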
def setUpModule():
  _exit_stack.enter_context(jtu.set_host_platform_device_count(8))

  global topology
  if jtu.test_device_matches(["tpu"]):
    with jtu.ignore_warning(message="the imp module is deprecated"):
      resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    # Do TPU init at beginning since it will wipe out all HBMs.
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
  else:
    topology = None


def tearDownModule():
  _exit_stack.close()

class ShardingTest(tf_test_util.JaxToTfTestCase):
  """Tests that inspect the HLO for the sharding annotations."""

  def setUp(self):
    super().setUp()
    if jtu.test_device_matches(["gpu"]):
      raise unittest.SkipTest("Sharding HLO tests not useful for GPU")

    if len(jax.devices()) < 2:
      raise unittest.SkipTest("Test requires at least 2 local devices")
    self.devices = np.array(jax.devices()[:2])  # use 2 devices

    self.warning_ctx = jtu.ignore_warning(
        message="jax2tf.convert with native_serialization=False is deprecated"
    )
    self.warning_ctx.__enter__()

  def tearDown(self):
    self.warning_ctx.__exit__(None, None, None)
    super().tearDown()

  def log_jax_hlo(self, f_jax, args: Sequence[Any], *,
                  num_replicas=1, num_partitions=2):
    """Logs the HLO generated from JAX before and after optimizations."""
    jax_comp = f_jax.lower(*args).compiler_ir(dialect="stablehlo")
    jax_hlo = str(jax_comp)
    logging.info("[%s] got JAX HLO %s", self._testMethodName, jax_hlo)

    # We only dump JAX optimized code on the TPU
    if jtu.test_device_matches(["tpu"]):
      backend = xla_bridge.get_backend()
      device_assignment = np.arange(num_partitions * num_replicas)
      device_assignment = np.reshape(device_assignment, (-1, num_partitions))
      use_spmd_partitioning = num_partitions > 1
      compile_options = compiler.get_compile_options(
          num_replicas=num_replicas,
          num_partitions=num_partitions,
          device_assignment=device_assignment,
          use_spmd_partitioning=use_spmd_partitioning,
      )
      jax_optimized_hlo = backend.compile(
          jax_hlo, compile_options).hlo_modules()[0].to_string()
      logging.info("[%s] got JAX optimized HLO for platform %s %s",
                   self._testMethodName, backend.platform, jax_optimized_hlo)

  def device_assignment(self,
                        computation_shape=(1, 1, 1, 2),
                        num_replicas=1):
    self.assertEqual(jtu.device_under_test(), "tpu")
    return tf.tpu.experimental.DeviceAssignment.build(
        topology, computation_shape=computation_shape,
        num_replicas=num_replicas)

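  # The text returned by `tf_hlo` below is XLA HLO text. A sharding annotation
  # appears as a "Sharding" custom call; roughly (the exact formatting depends
  # on the XLA version):
  #
  #   %custom-call.2 = f32[10,20]{1,0} custom-call(%arg0.1),
  #       custom_call_target="Sharding", sharding={devices=[1,2]0,1}
  #
  # The regular expressions passed to `check_sharding` in the tests match
  # against lines of this form.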
  def tf_hlo(self, f_tf, args_tf: Sequence[Any]) -> str:
    """Gets the unoptimized HLO from TF."""
    f_tf_fun = tf.function(f_tf, autograph=False, jit_compile=True)
    logging.info("[%s] Got TF graph %s",
                 self._testMethodName,
                 f_tf_fun.get_concrete_function(*args_tf).graph.as_graph_def())
    device_name = f"/device:{jtu.device_under_test().upper()}:0"
    tf_hlo_generator = f_tf_fun.experimental_get_compiler_ir(*args_tf)
    tf_hlo = tf_hlo_generator(
        stage="hlo", platform_name=jtu.device_under_test().upper()
    )
    logging.info("[%s] got TF HLO %s", self._testMethodName, tf_hlo)
    # TODO(necula): TensorFlow doesn't support getting the optimized_hlo on TFRT
    # TPU devices. But it doesn't seem like we're using it anyway.
    #
    # tf_optimized_hlo = tf_hlo_generator(stage="optimized_hlo",
    #                                     platform_name=platform_name)
    # logging.info("[%s] got TF optimized HLO for %s: %s", self._testMethodName,
    #              platform_name, tf_optimized_hlo)
    # Before we check, we drop the metadata= at the end of tf_hlo
    return re.sub(r'metadata=.*', '', tf_hlo)

  def GEQ(self, value):
    # Constructs an expected ">= value" count. See `check_sharding`.
    return (">=", value)

  def check_sharding(self, f_tf, args_tf: Sequence[Any], *,
                     checks=()):
    """Checks the sharding annotations in the TF HLO.

    Args:
      f_tf: the TF callable
      args_tf: the TF args
      checks: a list of tuples. The first element is a regular expression, the
        second element is an integer representing the expected number of
        occurrences of the regular expression in the TF HLO. As a special case,
        the second element can be the result of `self.GEQ(v)` to check that
        the number of occurrences is greater than or equal to `v`.
    """
    tf_hlo = self.tf_hlo(f_tf, args_tf)
    for check_re, expected_count in checks:
      count = len(re.findall(check_re, tf_hlo))
      if isinstance(expected_count, int):
        self.assertEqual(
            count, expected_count,
            (f"regular expression `{check_re}` expected to occur "
             f"{expected_count} times but occurs {count} times in "
             f"the TF HLO.\nThis is the TF HLO:\n{tf_hlo}"))
      elif isinstance(expected_count, tuple) and expected_count[0] == ">=":
        self.assertGreaterEqual(
            count, expected_count[1],
            (f"regular expression `{check_re}` expected to occur "
             f"at least {expected_count[1]} times but occurs {count} times in "
             f"the TF HLO.\nThis is the TF HLO:\n{tf_hlo}"))
      else:
        assert False, f"unexpected expected_count: {expected_count}"

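  # A hypothetical example of the `checks` argument (the tests below contain
  # the real ones): expect exactly one annotation that shards the f32[10,20]
  # argument over a 1x2 device grid (rows unsplit, columns split across the 2
  # devices), plus at least one replicated annotation somewhere in the HLO:
  #
  #   checks=[
  #       (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", 1),
  #       (r"custom_call_target.*Sharding.*sharding.*replicated", self.GEQ(1)),
  #   ]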
  @jtu.parameterized_filterable(
      kwargs=[
          dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
               in_shardings=in_shardings, out_shardings=out_shardings)
          for in_shardings in ("missing", None, "P")
          for out_shardings in ("missing", None, "P")
      ])
  @jtu.with_mesh([("x", 2)])
  def test_pjit_basic(self, in_shardings="P", out_shardings="P"):
    # Ensure that we can distinguish the inputs and outputs by shape
    def f_jax(x):  # f32[10,20] -> f32[20,10]
      return jnp.sin(x.T)

    pjit_kwargs = {}
    if in_shardings != "missing":
      pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None)
    if out_shardings != "missing":
      pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None)
    f_jax = pjit.pjit(f_jax, **pjit_kwargs)

    x_shape = (10, 20)
    x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)

    self.log_jax_hlo(f_jax, [x], num_partitions=2)

    @tf.function(autograph=False, jit_compile=True)
    def f_tf(x):
      f_converted = jax2tf.convert(f_jax)
      if jtu.test_device_matches(["tpu"]):
        return tf.compat.v1.tpu.rewrite(
            f_converted, [tf.convert_to_tensor(x)],
            device_assignment=self.device_assignment(
                computation_shape=[1, 1, 1, 2],
            ))[0]
      else:
        return f_converted(x)

    # Annotation count for the input
    count_in_P = 1 if in_shardings == "P" else 0
    if config.jax2tf_default_native_serialization.value:
      # With native serialization even unspecified in_shardings turn into replicated
      count_in_replicated = 1 if in_shardings in [None, "missing"] else 0
    else:
      count_in_replicated = 1 if in_shardings is None else 0
    # Annotation count for the output
    count_out_P = 1 if out_shardings == "P" else 0
    count_out_replicated = 1 if out_shardings is None else 0

    self.check_sharding(
        jax2tf.convert(f_jax), [x],
        checks=[
            # The argument
            (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]",
             count_in_P),
            # The result
            (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]",
             count_out_P),
        ])
    # TODO(b/326476605): Change the condition below if required.
    if in_shardings not in [None, "missing"] and out_shardings is not None:
      self.check_sharding(
          jax2tf.convert(f_jax), [x],
          checks=[
              (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated",
               count_in_replicated),
              (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated",
               count_out_replicated),
              (r"custom_call_target.*Sharding",
               count_in_P + count_in_replicated + count_out_P + count_out_replicated),
          ])

    res_jax = f_jax(x)
    res_tf = f_tf(x)
    self.assertAllClose(res_tf.numpy(), res_jax)

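  # For debugging, the same annotations can be inspected on the JAX side
  # without going through TF; a minimal sketch (assuming a 1-D mesh "x" with
  # 2 devices and an `x` array as above):
  #
  #   mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]), ("x",))
  #   s_in = NamedSharding(mesh, P(None, "x"))
  #   print(jax.jit(jnp.sin, in_shardings=s_in).lower(x).as_text())
  #
  # which shows the corresponding mhlo.sharding attributes in the StableHLO.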
  @jtu.with_mesh([("x", 2)])
  def test_pjit_variable_arg(self):
    # The first argument is a tf.Variable
    @partial(pjit.pjit, in_shardings=(P(None, "x"), P("x", None)),
             out_shardings=None)
    def f_jax(x, y):  # f32[10,20], f32[20,30] -> f32[10,30]
      return x @ y

    shape_x = (10, 20)
    x = np.arange(np.prod(shape_x), dtype=np.float32).reshape(shape_x)
    shape_y = (20, 30)
    y = np.arange(np.prod(shape_y), dtype=np.float32).reshape(shape_y)

    self.log_jax_hlo(f_jax, [x, y], num_partitions=2)

    x_v = tf.Variable(x)
    f_tf = lambda y: jax2tf.convert(f_jax)(x_v, y)

    self.check_sharding(
        f_tf, [y],
        checks=[
            # The variable argument
            (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", 1),
            # The y argument
            (r"f32\[20,30\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1),
            # The output sharding
            (r"f32\[10,30\].*custom_call_target.*Sharding.*sharding.*replicated", 1),
            # No other annotations
            (r"custom_call_target.*Sharding", 3),
        ])

  @jtu.with_mesh([("x", 2)])
  def test_pjit_closed_over_const(self):
    x = np.ones((10, 20), dtype=np.float32)
    const = jnp.full((10, 20), 7, dtype=np.float32)

    @partial(pjit.pjit, in_shardings=(P("x"),), out_shardings=None)
    def f_jax(x):  # f32[10,20] -> f32[20,10]
      return (x * const).T

    @tf.function(autograph=False, jit_compile=True)
    def f_tf(x):
      f_converted = jax2tf.convert(f_jax)
      if jtu.test_device_matches(["tpu"]):
        return tf.compat.v1.tpu.rewrite(
            f_converted, [tf.convert_to_tensor(x)],
            device_assignment=self.device_assignment(
                computation_shape=[1, 1, 1, 2])
        )[0]
      else:
        return f_converted(x)

    self.check_sharding(
        jax2tf.convert(f_jax), [x],
        checks=[
            # x
            (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]",
             1),
            # The result
            (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated",
             self.GEQ(1)),
        ])

    res_jax = f_jax(x)
    res_tf = f_tf(x)
    self.assertAllClose(res_tf, res_jax)

  @jtu.parameterized_filterable(
      kwargs=[
          dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint}_poly={poly}",
               nested_pjit=nested_pjit, constraint=constraint, poly=poly)
          # We add a constraint either with a nested pjit or with a
          # with_sharding_constraint
          for nested_pjit in (True, False)
          for constraint in (None, "P")
          for poly in (None, "2*b1,_", "_,b2", "2*b1,b2")
      ])
  @jtu.with_mesh([("x", 2)])
  def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P", poly="2*b1,b2"):
    constraint_sharding = P("x", None) if constraint == "P" else None

    @partial(pjit.pjit, in_shardings=None,
             out_shardings=None)
    def f_jax(x):  # x: f32[10, 20], optionally some axes as polymorphic
      y = jnp.concatenate([x, x], axis=1)  # y: f32[10, 40]
      if nested_pjit:
        y = pjit.pjit(lambda y: y, in_shardings=constraint_sharding,
                      out_shardings=constraint_sharding)(y)
      else:
        y = jax.lax.with_sharding_constraint(y, constraint_sharding)
      return jnp.concatenate([y, y], axis=1)  # res: f32[10, 80]

    shape = (10, 20)
    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

    self.log_jax_hlo(f_jax, [x], num_partitions=2)
    f_tf = jax2tf.convert(f_jax, polymorphic_shapes=poly)

    # A nested pjit adds two constraint annotations for `y` (one for its input
    # and one for its output), while with_sharding_constraint adds only one.
    count_inner_sharding = (2 if nested_pjit else 1) if constraint == "P" else 0
    count_inner_replicated = (2 if nested_pjit else 1) if constraint != "P" else 0
    self.check_sharding(
        f_tf, [x],
        checks=[
            # The input argument
            (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", 1),
            # The y argument
            (r"f32\[10,40\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]",
             count_inner_sharding),
            (r"f32\[10,40\].*custom_call_target.*Sharding.*sharding.*replicated",
             count_inner_replicated),
            # The output sharding
            (r"f32\[10,80\].*custom_call_target.*Sharding.*sharding.*replicated", 1),
            # No other annotations
            (r"custom_call_target.*Sharding",
             2 + count_inner_sharding + count_inner_replicated),
        ])

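  # The gradient tests below rely on jax2tf.convert attaching a TF custom
  # gradient (with_gradient=True is the default) that calls the converted JAX
  # VJP function. This is what lets tf.GradientTape differentiate through the
  # converted computation, and it is also why sharding annotations show up in
  # the gradient HLO.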
  @jtu.parameterized_filterable(
      kwargs=[
          dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
               in_shardings=in_shardings, out_shardings=out_shardings)
          for in_shardings in ("missing", None, "P")
          for out_shardings in ("missing", None, "P")
      ])
  def test_grad_pjit(self, in_shardings="P", out_shardings=None):
    if not config.jax2tf_default_native_serialization.value:
      self.skipTest("TODO: failure in non-native serialization")
    local_devices = list(jax.local_devices())
    size = 2
    if len(local_devices) < size:
      raise unittest.SkipTest(f"Test requires {size} local devices")
    mesh_devices = np.array(local_devices[:size]).reshape((2,))
    mesh = jax.sharding.Mesh(mesh_devices, ("x",))

    def f_jax(x):  # x: f32[10,20] -> f32[20,10]
      return jnp.sin(x.T)

    pjit_kwargs = {}
    if in_shardings != "missing":
      pjit_kwargs["in_shardings"] = (
          NamedSharding(mesh, P(None, "x")) if in_shardings == "P" else None)
    if out_shardings != "missing":
      pjit_kwargs["out_shardings"] = (
          NamedSharding(mesh, P("x", None)) if out_shardings == "P" else None)
    f_jax = pjit.pjit(f_jax, **pjit_kwargs)
    x_shape = (10, 20)
    x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)

    def f_grad_tf(x_v, res_ct):
      with tf.GradientTape(persistent=True) as tape:
        tape.watch(x_v)
        with tf.GradientTape() as tape2:
          tape2.watch(x_v)
          res_tf = jax2tf.convert(f_jax)(x_v)
        dy_dx = tape.gradient(res_tf, x_v, output_gradients=res_ct)
      d2y_dx2 = tape.gradient(dy_dx, x_v)
      return d2y_dx2

    # Annotation count for the primal input and the grad output
    count_in_P = self.GEQ(2) if in_shardings == "P" else 0
    if config.jax2tf_default_native_serialization.value:
      # With native serialization even unspecified shardings turn into replicated
      count_in_replicated = self.GEQ(2) if in_shardings in [None, "missing"] else 0
    else:
      count_in_replicated = self.GEQ(2) if in_shardings is None else 0
    # Annotation count for the cotangent input
    count_out_P = self.GEQ(1) if out_shardings == "P" else 0
    if config.jax2tf_default_native_serialization.value:
      # With native serialization even unspecified shardings turn into replicated
      count_out_replicated = self.GEQ(1) if out_shardings in [None, "missing"] else 0
    else:
      count_out_replicated = self.GEQ(1) if out_shardings is None else 0

    self.check_sharding(
        f_grad_tf, [x, x.T],
        checks=[
            # The input primal argument, and the output grad
            (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", count_in_P),
            # The primal result, and the input cotangent
            (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P),
        ])
    # TODO(b/326476605): Change the condition below if required.
    if out_shardings not in [None, "missing"] and in_shardings not in [None, "missing"]:
      self.check_sharding(
          f_grad_tf, [x, x.T],
          checks=[
              (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", count_in_replicated),
              # The primal result, and the input cotangent
              (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P),
          ])

  def test_grad_sharding_different_mesh(self):
    # Convert with two similar meshes, the only difference being
    # the order of the devices. grad should not fail.
    # https://github.com/jax-ml/jax/issues/21314
    devices = jax.local_devices()[:2]
    if len(devices) < 2:
      raise unittest.SkipTest("Test requires 2 local devices")
    def f_jax(x):
      return jnp.sum(x * 2.)

    mesh = Mesh(devices, "i")
    # The same mesh with reversed order of devices
    mesh_rev = Mesh(list(reversed(devices)), "i")
    shardings = NamedSharding(mesh, jax.sharding.PartitionSpec(("i",)))
    shardings_rev = NamedSharding(mesh_rev, jax.sharding.PartitionSpec(("i",)))

    f_tf = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings)),
                       autograph=False)
    f_tf_rev = tf.function(jax2tf.convert(pjit.pjit(f_jax, in_shardings=shardings_rev)),
                           autograph=False)
    inp = np.ones((2, 4), dtype=np.float32)

    input_v = tf.Variable(inp)
    with tf.GradientTape(persistent=True) as tape:
      tape.watch(input_v)
      res_tf = f_tf(input_v)
      g = tape.gradient(res_tf, input_v)

    with tf.GradientTape(persistent=True) as tape:
      tape.watch(input_v)
      res_tf_rev = f_tf_rev(input_v)
      g_rev = tape.gradient(res_tf_rev, input_v)
    self.assertAllClose(g, g_rev)

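  # With graph (non-native) serialization, sharded arguments and results are
  # lowered to XlaSharding ops, which are only meaningful when the converted
  # function is compiled (tf.function with jit_compile=True, or tpu.rewrite).
  # Calling such a function eagerly therefore raises an error, which the next
  # test checks.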
  @jtu.parameterized_filterable(
      kwargs=[
          dict(testcase_name=f"_func={func}", func=func)
          for func in ("pjit_sharded", "pjit_replicated",
                       "nested_pjit_sharded", "nested_pjit_replicated")
      ])
  def test_pjit_eager_error(self, func="pjit_sharded"):
    if config.jax2tf_default_native_serialization.value:
      raise unittest.SkipTest("There is no error in eager mode for native serialization")

    # Define some test functions
    @partial(pjit.pjit, in_shardings=(P("x"),),
             out_shardings=None)
    def f_pjit_sharded(a):
      return a + a

    @partial(pjit.pjit, in_shardings=None,
             out_shardings=None)
    def f_pjit_replicated(a):
      return a + a

    def f_nested_pjit_sharded(a):
      return a + pjit.pjit(jnp.sin, in_shardings=(P("x"),), out_shardings=None)(a)

    def f_nested_pjit_replicated(a):
      return a + pjit.pjit(jnp.sin, in_shardings=None, out_shardings=None)(a)

    shape = (8, 10)
    a = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

    if func == "pjit_sharded":
      f_jax = f_pjit_sharded
    elif func == "pjit_replicated":
      f_jax = f_pjit_replicated
    elif func == "nested_pjit_sharded":
      f_jax = f_nested_pjit_sharded
    elif func == "nested_pjit_replicated":
      f_jax = f_nested_pjit_replicated
    else:
      assert False

    with Mesh(self.devices, axis_names=("x",)):
      _ = f_jax(a)
      with self.assertRaisesRegex(
          ValueError,
          "function with sharded arguments or results must be used under a `tf.function` context"):
        jax2tf.convert(f_jax)(a)

  @jtu.ignore_warning(category=UserWarning,
                      message="all_to_all .* are only implemented properly for TPUs and GPUs .*")
  def test_shmap_all_to_all(self):
    if jtu.test_device_matches(["cpu"]):
      raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash")

    mesh = Mesh(self.devices, axis_names=('x'))
    a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4))

    @partial(pjit.pjit,
             in_shardings=(P('x', None),), out_shardings=P(None, 'x'))
    @partial(shard_map, mesh=mesh,
             in_specs=(P('x', None),), out_specs=P(None, 'x'))
    def f_jax(b):  # b: f32[2, 4]
      return lax.all_to_all(b, 'x', split_axis=1, concat_axis=1, tiled=True)

    @tf.function(autograph=False, jit_compile=True)
    def f_tf(a):
      f_converted = jax2tf.convert(f_jax, native_serialization=True)
      if jtu.test_device_matches(["tpu"]):
        return tf.compat.v1.tpu.rewrite(
            f_converted, [tf.convert_to_tensor(a)],
            device_assignment=self.device_assignment(
                computation_shape=[1, 1, 1, 2])
        )[0]
      else:
        return f_converted(a)

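    # For this input a = np.arange(16, dtype=np.float32).reshape(4, 4), the
    # reference computation below yields
    #   [[ 0.,  1.,  8.,  9.,  2.,  3., 10., 11.],
    #    [ 4.,  5., 12., 13.,  6.,  7., 14., 15.]]
    # i.e. f32[2, 8]: each row-shard swaps half of its columns with the other
    # shard (split_axis=1, concat_axis=1), and out_specs then concatenates the
    # shards along axis 1.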
    with mesh:
      res_jax = f_jax(a)  # res: f32[2, 8]
      b0, b1 = np.split(a, 2, axis=0)  # The shard_map in_specs splits on axis 0
      b00, b01 = np.split(b0, 2, axis=1)  # split_axis=1
      b10, b11 = np.split(b1, 2, axis=1)
      b0 = np.concatenate([b00, b10], axis=1)  # concat_axis=1
      b1 = np.concatenate([b01, b11], axis=1)
      res = np.concatenate([b0, b1], axis=1)  # out_specs concatenates on axis 1
      self.assertAllClose(res_jax, res)
      res_tf = f_tf(a)
      self.assertAllClose(res_tf, res_jax)

    # TODO(b/274648842): Failed to GetCompilerIr
    # self.check_sharding(
    #     jax2tf.convert(f_jax, native_serialization=True), [a],
    #     checks=[])

  @unittest.skip("TODO(b/268295912): ShardingRemover crash on all platforms")
  def test_repro_xla_bug_shmap_collective_permute(self):
    mesh = Mesh(self.devices, axis_names=('x'))

    @partial(pjit.pjit,
             in_shardings=(P('x', None),), out_shardings=P('x', None))
    @partial(shard_map, mesh=mesh,
             in_specs=(P('x', None),), out_specs=P('x', None))
    def f_jax(b):  # b: f32[2, 4]
      axis_size = lax.psum(1, 'x')
      perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
      return lax.ppermute(b, 'x', perm=perm)

    with mesh:
      a = np.arange(4 * 4).reshape((4, 4))
      res_jax = f_jax(a)
      b0, b1 = np.split(a, 2, axis=0)  # The shard_map splits on axis 0
      b0, b1 = b1, b0
      expected = np.concatenate([b0, b1], axis=0)  # out_specs concatenates on axis 0
      self.assertAllClose(res_jax, expected)

      # XLA bug: invoke f_tf without tpu.replicate
      f_tf = tf.function(
          jax2tf.convert(f_jax, native_serialization=True),
          autograph=False, jit_compile=True)

      res_tf = f_tf(a)
      self.assertAllClose(res_tf, expected)

  @jtu.parameterized_filterable(
      kwargs=[
          dict(testcase_name=f"_poly={poly}", poly=poly)
          for poly in (None, "2*b1,_", "_,b2", "2*b1,b2")
      ])
  def test_shmap_collective_permute(self, poly=None):
    if jtu.test_device_matches(["cpu"]):
      raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash")
    mesh = Mesh(self.devices, axis_names=('x'))
    a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4))

    @partial(pjit.pjit,
             in_shardings=(P('x', None),), out_shardings=P('x', None))
    @partial(shard_map, mesh=mesh,
             in_specs=(P('x', None),), out_specs=P('x', None))
    def f_jax(b):  # b: f32[2, 4]
      axis_size = lax.psum(1, 'x')
      perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
      return lax.ppermute(b, 'x', perm=perm)

    @tf.function(autograph=False, jit_compile=True)
    def f_tf(a):
      f_converted = jax2tf.convert(f_jax, native_serialization=True,
                                   polymorphic_shapes=poly)
      if jtu.test_device_matches(["tpu"]):
        res = tf.compat.v1.tpu.rewrite(
            f_converted, [tf.convert_to_tensor(a)],
            device_assignment=self.device_assignment(
                computation_shape=[1, 1, 1, 2])
        )[0]
      else:
        res = f_converted(a)
      return res

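    # With the 2 devices used in these tests the permutation is
    # [(0, 1), (1, 0)], so the two row-shards simply swap places: for
    # a = np.arange(16, dtype=np.float32).reshape(4, 4) the expected result is
    # rows 2-3 of `a` followed by rows 0-1, as computed by the reference code
    # below.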
    with mesh:
      res_jax = f_jax(a)
      b0, b1 = np.split(a, 2, axis=0)  # The shard_map splits on axis 0
      b0, b1 = b1, b0
      expected = np.concatenate([b0, b1], axis=0)  # out_specs concatenates on axis 0
      self.assertAllClose(res_jax, expected)
      res_tf = f_tf(a)
      self.assertAllClose(res_tf, expected)
    # TODO(b/274648842): Failed to GetCompilerIr
    # self.check_sharding(
    #     jax2tf.convert(f_jax, native_serialization=True), [a],
    #     checks=[])

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())