[jax2tf] Add more sharding tests with shape polymorphism
PiperOrigin-RevId: 521471546
commit 05249ec770 (parent ff313a37a2)
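
The diff below adds shape-polymorphism variants to the jax2tf sharding tests. As a rough orientation (a sketch with illustrative names and shapes, not code from the test file), a polymorphic_shapes spec such as "b1,_" marks the first dimension as a symbolic size b1 and keeps the second dimension concrete, so one converted function serves every batch size:

import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):  # x: f32[b1, 20] once converted with polymorphic_shapes="b1,_"
  return jnp.concatenate([x, x], axis=1)  # result: f32[b1, 40]

f_tf = tf.function(
    jax2tf.convert(f, polymorphic_shapes="b1,_"),
    autograph=False,
    input_signature=[tf.TensorSpec([None, 20], tf.float32)])

# The same traced function accepts any leading dimension.
print(f_tf(np.ones((3, 20), np.float32)).shape)   # (3, 40)
print(f_tf(np.ones((10, 20), np.float32)).shape)  # (10, 40)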
@@ -11,7 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for the jax2tf conversion of pjit."""
+"""Tests for the jax2tf conversion of pjit.
+
+To verify that the tests do run indeed on multiple devices you can run
+
+  perftools/gputools/profiler/jfprof.sh jax/experimental/jax2tf/tests:sharding_test_tpu -- -c opt --test_filter=ShardingTest.test_shmap_all_to_all --test_arg=--vmodule=jax2tf=3 --
+
+"""
 import contextlib
 from functools import partial
 import logging
@@ -72,10 +78,6 @@ def tearDownModule():
 
 class ShardingTest(tf_test_util.JaxToTfTestCase):
   """Tests that inspect the HLO for the sharding annotations.
-
-  To verify that the tests do run indeed on multiple devices you can run
-
-    perftools/gputools/profiler/jfprof.sh jax/experimental/jax2tf/tests:sharding_test_tpu -- -c opt --test_filter=ShardingTest.test_shmap_all_to_all --test_arg=--vmodule=jax2tf=3 --
   """
   def setUp(self):
     super().setUp()
@@ -278,7 +280,6 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
         (r"custom_call_target.*Sharding", 3)
     ])
 
-
   @jtu.with_mesh([("x", 2)])
   def test_pjit_closed_over_const(self):
     x = np.ones((10, 20), dtype=np.float32)
@@ -316,18 +317,21 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
     self.assertAllClose(res_tf, res_jax)
 
   @parameterized.named_parameters(
-      dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint=}",
-           nested_pjit=nested_pjit)
+      dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint=}_poly={poly}",
+           nested_pjit=nested_pjit, constraint=constraint, poly=poly)
       # We add a constraint either with a nested pjit or with a sharding_constraint
       for nested_pjit in (True, False)
      for constraint in (None, "P")
+      for poly in (None, "b1,_", "_,b2", "b1,b2")
   )
   @jtu.with_mesh([("x", 2)])
-  def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P"):
+  def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P", poly=None):
+    if poly is not None:
+      raise unittest.SkipTest("TODO: Sharding custom calls lack shape refinement")
     constraint_sharding = P("x", None) if constraint == "P" else None
     @partial(pjit.pjit, in_shardings=None,
              out_shardings=None)
-    def f_jax(x):  # x: f32[10, 20]
+    def f_jax(x):  # x: f32[10, 20], optionally some axes as polymorphic
       y = jnp.concatenate([x, x], axis=1)  # y: f32[10, 40]
       if nested_pjit:
         y = pjit.pjit(lambda y: y, in_shardings=constraint_sharding,
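
The hunk above parameterizes test_pjit_sharding_constraint over how the constraint is introduced: either a nested pjit or a sharding constraint. A minimal sketch of the two forms (my illustration with hypothetical names; a one-device mesh is enough to build them):

import jax
import numpy as np
from jax.experimental import pjit
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("x",))
constraint = NamedSharding(mesh, P("x", None))

def constrain_with_nested_pjit(y):
  # An identity pjit whose in/out shardings pin down the layout of y.
  return pjit.pjit(lambda v: v, in_shardings=constraint, out_shardings=constraint)(y)

def constrain_with_sharding_constraint(y):
  # The same constraint attached explicitly inside the traced function.
  return jax.lax.with_sharding_constraint(y, constraint)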
@@ -340,10 +344,11 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
     x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
 
     self.log_jax_hlo(f_jax, [x], num_partitions=2)
-    f_tf = jax2tf.convert(f_jax)
+    f_tf = jax2tf.convert(f_jax, polymorphic_shapes=poly)
 
     # If we use a pjit then we see two constraints, otherwise only 1
-    count_inner_sharding = 2 if nested_pjit else 1
+    count_inner_sharding = (2 if nested_pjit else 1) if constraint == "P" else 0
+    count_inner_replicated = (2 if nested_pjit else 1) if constraint != "P" else 0
     self.check_sharding(
         f_tf, [x],
         checks=[
@@ -352,13 +357,14 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
           # The y argument
           (r"f32\[10,40\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]",
            count_inner_sharding),
+          (r"f32\[10,40\].*custom_call_target.*Sharding.*sharding.*replicated",
+           count_inner_replicated),
           # The output sharding
           (r"f32\[10,80\].*custom_call_target.*Sharding.*sharding.*replicated", 1),
           # No other annotations
-          (r"custom_call_target.*Sharding", 2 + count_inner_sharding)
+          (r"custom_call_target.*Sharding", 2 + count_inner_sharding + count_inner_replicated)
       ])
 
-
   @parameterized.named_parameters(
       dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
            in_shardings=in_shardings, out_shardings=out_shardings)
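
check_sharding in the hunks above greps the lowered HLO for Sharding custom calls. Roughly the same signal can be reproduced outside the test harness by lowering a constrained function and counting the custom calls (a sketch with illustrative names; a single device suffices):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("x",))

def f(x):  # x: f32[10, 20]
  y = jnp.concatenate([x, x], axis=1)  # y: f32[10, 40]
  y = jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P("x", None)))
  return jnp.sin(y)

lowered = jax.jit(f).lower(np.ones((10, 20), np.float32))
# The constraint shows up as a custom_call @Sharding in the lowered module.
print(lowered.as_text().count("@Sharding"))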
@@ -660,7 +666,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
     #     jax2tf.convert(f_jax, native_serialization=True), [a],
     #     checks=[])
 
-  @unittest.skip("TODO(b/268295912): ShardingRemover crash")
+  @unittest.skip("TODO(b/268295912): ShardingRemover crash, on all platforms!!!")
   def test_repro_xla_bug_shmap_collective_permute(self):
     mesh = Mesh(self.devices, axis_names=('x'))
 
@@ -689,9 +695,15 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
       res_tf = f_tf(a)
       self.assertAllClose(res_tf, expected)
 
-  def test_shmap_collective_permute(self):
+  @parameterized.named_parameters(
+      dict(testcase_name=f"_poly={poly}", poly=poly)
+      for poly in (None, "b1,_", "_,b2", "b1,b2")
+  )
+  def test_shmap_collective_permute(self, poly=None):
     if jtu.device_under_test() == "cpu":
       raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash")
+    if poly is not None:
+      raise unittest.SkipTest("TODO: Sharding custom calls lack shape refinement")
     mesh = Mesh(self.devices, axis_names=('x'))
     a = np.arange(np.prod(4 * 4), dtype=np.float32).reshape((4, 4))
 
@@ -706,7 +718,8 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
 
     @tf.function(autograph=False, jit_compile=True)
     def f_tf(a):
-      f_converted = jax2tf.convert(f_jax, native_serialization=True)
+      f_converted = jax2tf.convert(f_jax, native_serialization=True,
+                                   polymorphic_shapes=poly)
       if jtu.device_under_test() == "tpu":
         res = tf.compat.v1.tpu.rewrite(
             f_converted, [tf.convert_to_tensor(a)],
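
test_shmap_collective_permute converts a shard_map'ed collective permute through jax2tf. On the JAX side the permuted computation looks roughly like this (a sketch adapted to however many devices are visible, not the exact test body):

from functools import partial
import jax
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, axis_names=("x",))

@partial(shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None))
def shift(a):
  # Each shard hands its block to the next device along the "x" mesh axis.
  perm = [(i, (i + 1) % n) for i in range(n)]
  return jax.lax.ppermute(a, axis_name="x", perm=perm)

a = np.arange(4 * n * 2, dtype=np.float32).reshape((4 * n, 2))
print(shift(a))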