Mirror of https://github.com/ROCm/jax.git, synced 2025-04-14 10:56:06 +00:00
#sdy Add JAX backwards compatibility test.
This tests saving a module with one set of axis names but loading it with another set of axis names. It also tests the custom calls:

- `@Sharding`
- `@xla.sdy.GlobalToLocalShape`
- `@xla.sdy.LocalToGlobalShape`

Note that a number of other custom calls will be tested in the Shardy and XLA codebases; the way the testing utilities are set up here doesn't allow setting `out_shardings`, for example. So JAX can rely on the existence of those tests as stability guarantees, just like it does for StableHLO.

PiperOrigin-RevId: 732893432
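For context, a minimal sketch (not part of this commit) of the save/load round trip these compatibility tests guard, using the public `jax.export` API; the axis-name remapping itself happens inside the Shardy lowering and is exercised by the new test below.

```python
# Hedged sketch of the export round trip; the sharding/axis-name details
# tested in this commit live inside the serialized StableHLO, not shown here.
import jax
import jax.numpy as jnp
from jax import export

@jax.jit
def f(x):
  return x + 1

x = jnp.arange(8.0, dtype=jnp.float32).reshape(2, 4)
blob = export.export(f)(x).serialize()  # "save": version-stable bytes
restored = export.deserialize(blob)     # "load", possibly in a newer JAX
print(restored.call(x))                 # replays the saved module
```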
This commit is contained in:
parent ac493655bf
commit ed4a7bbab1
@@ -0,0 +1,57 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ruff: noqa
import datetime
from numpy import array, float32


# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2025_02_12 = dict(
    testdata_version=1,
    platform='tpu',
    custom_call_targets=['Sharding', 'xla.sdy.GlobalToLocalShape', 'xla.sdy.LocalToGlobalShape'],
    serialized_date=datetime.date(2025, 2, 12),
    inputs=(array([[0., 1., 2., 3.],
                   [4., 5., 6., 7.]], dtype=float32),),
    expected_outputs=(array([[4., 5., 6., 7.],
                             [0., 1., 2., 3.]], dtype=float32),),
    mlir_module_text=r"""
#loc1 = loc("x")
#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":1052:13)
#loc6 = loc("jit(func)/jit(main)/shard_map"(#loc3))
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22a\\\22=2]>}"}, mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2x4xf32> loc("x")) -> (tensor<2x4xf32> {jax.result_info = ""}) {
    %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>"}, mhlo.sharding = "{devices=[2,1]<=[2]}"} : (tensor<2x4xf32>) -> tensor<2x4xf32> loc(#loc5)
    %1 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%0) : (tensor<2x4xf32>) -> tensor<1x4xf32> loc(#loc6)
    %2 = call @xla.sdy.manual_computation_body(%1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22a\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>"}} : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc6)
    %3 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%2) : (tensor<1x4xf32>) -> tensor<2x4xf32> loc(#loc6)
    return %3 : tensor<2x4xf32> loc(#loc)
  } loc(#loc)
  func.func @xla.sdy.manual_computation_body(%arg0: tensor<1x4xf32> loc("jit(func)/jit(main)/shard_map"(#loc3))) -> tensor<1x4xf32> {
    %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>, source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>}> : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc7)
    return %0 : tensor<1x4xf32> loc(#loc6)
  } loc(#loc6)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":1051:10)
#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":1050:15)
#loc5 = loc("jit(func)/jit(main)/sharding_constraint"(#loc2))
#loc7 = loc("jit(func)/jit(main)/ppermute"(#loc4))
""",
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.8.8\x00\x01\x1d\x05\x01\x05\r\x01\x03\x0b\x03\x0b\x0f\x13\x17\x1b\x1f\x03\x97q\x13\x019\x0f\x07\x0b\x0b+\x0b\x0f\x13\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0f\x0b\x17\x0f\x0b\x1b\x0b\x0f\x0b\x17\x13\x039\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0f\x8f\x13\x0b\x0b\x0b\x0b#\x0b\x0b\x0b\x0b\x0b\x01\x05\x0f\x0b\x03\x0f\x17\x17\x07\x07\x17\x17\x17\x02v\x03\x1d\x1f!\x1f\x05\x11\x05\x13\x03\t\x0b\r\x05\x0f\x15\x17\x19\x1b\x05\x15\x11\x03\x00\x03\x03\x11\x13\x05\x17\x05\x19\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\x05!\x17\x07r\x10\x1b\x1d%\'\x05#\x17\x07j\x10\x1f\x1d+\x03\x05%\x03\x05\x05[/_\x05\'\x1d35\x05)\x17\x07n\x10\x15\x03\x03\x05e\x03\x01\x1d+\x1d-\x0b\x03\x05\x01\x1d/\x03\x03G\r\x01#\r\x03\x03M\r\x03O;\x1d1\x1d3\x1d5#\x0f\x13\x0b\x05\x1f\x11A\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x03]=\x1d7\x1d9\x1d;\x1d=\r\x07g=ikm=\x1d?\x1dA\x1dC\x1dE\x1dG\x01\x02\x02\x01\t)\x05\x05\x11\t)\x05\t\x11\t\t\x1d\x11\x03\x07\x03\x07\x11\x03\x05\x03\x05)\x05\t\t\x0b\x04\xb9\x05\x01Q\x03\t\x01\x07\x04\xa7\x03\x01\t\x05P\x03\x03\x07\x04]\x03\x0b\x17\x03\x0f)\x00\x03G1-\x05\x03\x07\x03\x01\x03F\x01\x07\x03\x05\x03\x03\x0bG\x017\t\x03\x05\x03\x05\x03F\x01\x0b\x03\x07\x03\x07\x07\x04\x03\x03\t\x05P\x01\r\x07\x04)\x03\x05\x0b\x03\x0b\x01\x00\tF#\x0f\x03\x05\x03\x01\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00r\x0bI7-3)+7\x13+#\x0f\x0b!Ae\x03Q\x1d\x05;=\x13%)=\x1f9i3\x11-\x15\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00return_v1\x00collective_permute_v1\x00call_v1\x00mhlo.frontend_attributes\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = #sdy.mesh<[\\"a\\"=2]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/shard_map\x00jit(func)/jit(main)/ppermute\x00x\x00mhlo.sharding\x00jit(func)/jit(main)/sharding_constraint\x00\x00#sdy.sharding_per_value<[<@mesh, [{\\"a\\"}, {}]>]>\x00xla.sdy.manual_computation_body\x00jax.result_info\x00main\x00public\x00xla.sdy.sharding\x00{devices=[2,1]<=[2]}\x00Sharding\x00xla.sdy.GlobalToLocalShape\x00xla.sdy.in_shardings\x00xla.sdy.manual_axes\x00#sdy<manual_axes{\\"a\\"}>\x00xla.sdy.out_shardings\x00xla.sdy.LocalToGlobalShape\x00\x08a\x11\x05;\x01\x0bEIKQS\x11?;a9A999\x11?;c9A999\x03C\x11?;o9A999\x0b9U9C;\x05WY',
    xla_call_module_version=9,
    nr_devices=2,
)  # End paste
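A quick standalone sanity check (numpy only, not part of the diff): with two devices and `perm = [(0, 1), (1, 0)]`, the `ppermute` swaps the two row shards, so `expected_outputs` above is simply `inputs` with its rows exchanged.

```python
import numpy as np

inputs = np.array([[0., 1., 2., 3.],
                   [4., 5., 6., 7.]], dtype=np.float32)
expected = np.array([[4., 5., 6., 7.],
                     [0., 1., 2., 3.]], dtype=np.float32)
# Each of the 2 devices holds one 1x4 row shard; the permutation swaps them.
np.testing.assert_array_equal(inputs[::-1], expected)
```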
@@ -57,6 +57,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxT
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Qr
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Sharding
from jax._src.internal_test_util.export_back_compat_test_data import tpu_stablehlo_dynamic_reduce_window
from jax._src.internal_test_util.export_back_compat_test_data import shardy_sharding_ops_with_different_meshes
from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_rng_bit_generator
from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_top_k
from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_approx_top_k
@@ -67,6 +68,7 @@ import jax.numpy as jnp

from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding as NS

from jax._src import config
from jax._src import test_util as jtu
@@ -1024,5 +1026,34 @@ class CompatTest(bctu.CompatTestBase):
    )


@jtu.with_config(jax_use_shardy_partitioner=True)
class ShardyCompatTest(bctu.CompatTestBase):
  def test_shardy_sharding_ops_with_different_meshes(self):
    # Tests whether we can save and load a module with meshes that have the
    # same axis sizes (and same order) but different axis names.
    # Also tests "Sharding", "xla.sdy.GlobalToLocalShape",
    # "xla.sdy.LocalToGlobalShape".
    if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2:
      self.skipTest("Test runs only on TPU with at least 2 devices")

    # Must use exactly 2 devices for the expected outputs from ppermute.
    devices = jax.devices()[:2]
    old_mesh = Mesh(devices, axis_names=('a',))

    def func(x):  # x: f32[2, 4]
      @partial(shard_map, mesh=old_mesh,
               in_specs=(P('a', None),), out_specs=P('a', None))
      def shard_map_func(x):  # x: f32[1, 4] (the per-device shard)
        axis_size = lax.psum(1, 'a')
        perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
        return lax.ppermute(x, 'a', perm=perm)
      x = jax.lax.with_sharding_constraint(x, NS(old_mesh, P('a', None)))
      return shard_map_func(x)

    data = self.load_testdata(shardy_sharding_ops_with_different_meshes.data_2025_02_12)
    with Mesh(devices, axis_names=('x',)):
      self.run_one_test(func, data)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
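For reference, a self-contained sketch of the computation under test, outside the compatibility harness; it assumes at least 2 devices on any backend, whereas the real test additionally pins the mesh names and TPU.

```python
# Hedged, standalone version of func/shard_map_func from the diff above.
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices()[:2], axis_names=('a',))

@partial(shard_map, mesh=mesh, in_specs=(P('a', None),), out_specs=P('a', None))
def roll(x):  # x: the local f32[1, 4] shard on each device
  n = lax.psum(1, 'a')  # size of mesh axis 'a' (here 2)
  return lax.ppermute(x, 'a', perm=[(j, (j + 1) % n) for j in range(n)])

print(roll(jnp.arange(8.0, dtype=jnp.float32).reshape(2, 4)))  # rows swapped
```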