[jax2tf] Add more sharding tests with shape polymorphism

PiperOrigin-RevId: 521471546
commit 05249ec770
parent ff313a37a2
Author: George Necula
Date: 2023-04-03 08:54:22 -07:00
Committed by: jax authors


@@ -11,7 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the jax2tf conversion of pjit."""
"""Tests for the jax2tf conversion of pjit.
To verify that the tests do run indeed on multiple devices you can run
perftools/gputools/profiler/jfprof.sh jax/experimental/jax2tf/tests:sharding_test_tpu -- -c opt --test_filter=ShardingTest.test_shmap_all_to_all --test_arg=--vmodule=jax2tf=3 --
"""
import contextlib
from functools import partial
import logging
@@ -72,10 +78,6 @@ def tearDownModule():
class ShardingTest(tf_test_util.JaxToTfTestCase):
"""Tests that inspect the HLO for the sharding annotations.
-To verify that the tests indeed run on multiple devices, you can run
-  perftools/gputools/profiler/jfprof.sh jax/experimental/jax2tf/tests:sharding_test_tpu -- -c opt --test_filter=ShardingTest.test_shmap_all_to_all --test_arg=--vmodule=jax2tf=3 --
"""
def setUp(self):
super().setUp()
@@ -278,7 +280,6 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
(r"custom_call_target.*Sharding", 3)
])
-@jtu.with_mesh([("x", 2)])
def test_pjit_closed_over_const(self):
x = np.ones((10, 20), dtype=np.float32)
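
For context, `test_pjit_closed_over_const` exercises a `pjit`-compiled function that captures a constant from the enclosing scope instead of taking it as an argument. A minimal standalone sketch of that pattern (illustrative only, not the test's actual body; runs on a single device with unconstrained shardings):

```python
from functools import partial

import numpy as np
import jax.numpy as jnp
from jax.experimental import pjit

const = jnp.full((10, 20), 7.0, dtype=np.float32)

@partial(pjit.pjit, in_shardings=None, out_shardings=None)
def f(x):
  # `const` is closed over, so it enters the computation as an embedded
  # constant rather than as a pjit argument.
  return x + const

x = np.ones((10, 20), dtype=np.float32)
res = f(x)
```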
@@ -316,18 +317,21 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(res_tf, res_jax)
@parameterized.named_parameters(
dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint=}",
nested_pjit=nested_pjit)
dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint=}_poly={poly}",
nested_pjit=nested_pjit, constraint=constraint, poly=poly)
# We add a constraint either with a nested pjit or with a sharding_constraint
for nested_pjit in (True, False)
for constraint in (None, "P")
+for poly in (None, "b1,_", "_,b2", "b1,b2")
)
@jtu.with_mesh([("x", 2)])
-def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P"):
+def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P", poly=None):
+  if poly is not None:
+    raise unittest.SkipTest("TODO: Sharding custom calls lack shape refinement")
constraint_sharding = P("x", None) if constraint == "P" else None
@partial(pjit.pjit, in_shardings=None,
out_shardings=None)
-def f_jax(x):  # x: f32[10, 20]
+def f_jax(x):  # x: f32[10, 20], optionally with some axes polymorphic
y = jnp.concatenate([x, x], axis=1) # y: f32[10, 40]
if nested_pjit:
y = pjit.pjit(lambda y: y, in_shardings=constraint_sharding,
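
For readers unfamiliar with the `poly` strings above ("b1,_", "_,b2", "b1,b2"): each named variable such as `b1` marks an axis whose size stays symbolic through lowering, while `_` keeps that axis at the concrete size of the example argument. A minimal sketch of `jax2tf.convert` with `polymorphic_shapes` (names illustrative):

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):  # x: f32[b1, 20] under polymorphic_shapes="b1,_"
  return jnp.concatenate([x, x], axis=1)  # result: f32[b1, 40]

# The leading axis becomes the symbolic dimension b1, so the converted
# function accepts any batch size; "_" pins the second axis to 20.
f_tf = jax2tf.convert(f, polymorphic_shapes="b1,_")
res = f_tf(np.ones((5, 20), dtype=np.float32))  # also works for (7, 20), ...
```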
@@ -340,10 +344,11 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
self.log_jax_hlo(f_jax, [x], num_partitions=2)
-f_tf = jax2tf.convert(f_jax)
+f_tf = jax2tf.convert(f_jax, polymorphic_shapes=poly)
# If we use a pjit then we see two constraints, otherwise only 1
-count_inner_sharding = 2 if nested_pjit else 1
+count_inner_sharding = (2 if nested_pjit else 1) if constraint == "P" else 0
+count_inner_replicated = (2 if nested_pjit else 1) if constraint != "P" else 0
self.check_sharding(
f_tf, [x],
checks=[
@@ -352,13 +357,14 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
# The y argument
(r"f32\[10,40\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]",
count_inner_sharding),
(r"f32\[10,40\].*custom_call_target.*Sharding.*sharding.*replicated",
count_inner_replicated),
# The output sharding
(r"f32\[10,80\].*custom_call_target.*Sharding.*sharding.*replicated", 1),
# No other annotations
(r"custom_call_target.*Sharding", 2 + count_inner_sharding)
(r"custom_call_target.*Sharding", 2 + count_inner_sharding + count_inner_replicated)
])
@parameterized.named_parameters(
dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
in_shardings=in_shardings, out_shardings=out_shardings)
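
Each entry in `checks` pairs a regular expression over the compiled HLO text with the number of matches the test expects. A rough illustration of what such a check asserts (`check_sharding` is a test utility whose exact implementation is not shown in this diff; `hlo_text` is assumed to hold the lowered module text):

```python
import re

def assert_match_count(hlo_text: str, pattern: str, expected: int) -> None:
  # Count occurrences of a sharding-annotation pattern in the HLO dump.
  found = len(re.findall(pattern, hlo_text))
  assert found == expected, f"{pattern!r}: found {found}, expected {expected}"

# E.g., exactly one replicated annotation on the f32[10,80] output:
# assert_match_count(
#     hlo_text, r"f32\[10,80\].*custom_call_target.*Sharding.*replicated", 1)
```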
@@ -660,7 +666,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
# jax2tf.convert(f_jax, native_serialization=True), [a],
# checks=[])
@unittest.skip("TODO(b/268295912): ShardingRemover crash")
@unittest.skip("TODO(b/268295912): ShardingRemover crash,on all platforms!!!")
def test_repro_xla_bug_shmap_collective_permute(self):
mesh = Mesh(self.devices, axis_names=('x'))
@@ -689,9 +695,15 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
res_tf = f_tf(a)
self.assertAllClose(res_tf, expected)
-def test_shmap_collective_permute(self):
+@parameterized.named_parameters(
+  dict(testcase_name=f"_poly={poly}", poly=poly)
+  for poly in (None, "b1,_", "_,b2", "b1,b2")
+)
+def test_shmap_collective_permute(self, poly=None):
if jtu.device_under_test() == "cpu":
raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash")
+if poly is not None:
+  raise unittest.SkipTest("TODO: Sharding custom calls lack shape refinement")
mesh = Mesh(self.devices, axis_names=('x'))
a = np.arange(np.prod(4 * 4), dtype=np.float32).reshape((4, 4))
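
The body of `f_jax` falls outside this hunk; per the test name it is presumably a `shard_map` of `lax.ppermute` over the "x" mesh axis. A self-contained sketch of that collective-permute pattern, assuming at least two devices whose count divides the leading axis of `a` (details illustrative, not the test's exact body):

```python
from functools import partial

import numpy as np
import jax
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
# Rotate shards by one position along the "x" mesh axis.
perm = [(i, (i + 1) % len(devices)) for i in range(len(devices))]

@partial(shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None))
def f_jax(b):  # b: the per-device shard of the input
  return lax.ppermute(b, "x", perm=perm)

a = np.arange(16, dtype=np.float32).reshape((4, 4))
res = f_jax(a)
```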
@@ -706,7 +718,8 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
@tf.function(autograph=False, jit_compile=True)
def f_tf(a):
-f_converted = jax2tf.convert(f_jax, native_serialization=True)
+f_converted = jax2tf.convert(f_jax, native_serialization=True,
+                             polymorphic_shapes=poly)
if jtu.device_under_test() == "tpu":
res = tf.compat.v1.tpu.rewrite(
f_converted, [tf.convert_to_tensor(a)],