mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Make the regex checking of export_tests less strict
PiperOrigin-RevId: 582704122
This commit is contained in:
parent
5c3da219c0
commit
118d85cd6c
@ -705,12 +705,12 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
expected_re = re.compile(
|
||||
# The top-level input it replicated
|
||||
r"func.func .* @main\(%arg0: tensor<16x4xf32> {mhlo.sharding = \"{replicated}\"}\).*"
|
||||
r"func.func .* @main\(%arg0: tensor<16x4xf32>.*mhlo.sharding = \"{replicated}\"}\).*"
|
||||
# We apply the in_shardings for f_jax
|
||||
r".*custom_call @Sharding\(%arg0\) {mhlo.sharding = \"{devices=\[2,1\]<=\[2\]}\"}.*"
|
||||
r".*custom_call @Sharding\(%arg0\).*mhlo.sharding = \"{devices=\[2,1\]<=\[2\]}\"}.*"
|
||||
r"%1 = .*call @call_exported_f_jax.*"
|
||||
# We apply the out_shardings for f_jax
|
||||
r".*custom_call @Sharding\(%1\) {mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*",
|
||||
r".*custom_call @Sharding\(%1\).*mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*",
|
||||
re.DOTALL)
|
||||
hlo = jax.jit(export.call_exported(exp)).lower(a_device).as_text()
|
||||
self.assertRegex(hlo, expected_re)
|
||||
@ -771,8 +771,8 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
primal_out_sharding = "{replicated}"
|
||||
|
||||
main = re.search(
|
||||
r"func.func public @main\(%arg0: tensor<10x20xf32> {mhlo.sharding = \"([^\"]+)\""
|
||||
r".*%arg1: tensor<20x10xf32> {mhlo.sharding = \"([^\"]+)\""
|
||||
r"func.func public @main\(%arg0: tensor<10x20xf32>.*mhlo.sharding = \"([^\"]+)\""
|
||||
r".*%arg1: tensor<20x10xf32>.*mhlo.sharding = \"([^\"]+)\""
|
||||
# result
|
||||
r".* -> \(tensor<10x20xf32>.*mhlo.sharding = \"([^\"]+)\"",
|
||||
vjp_module_str)
|
||||
|
Loading…
x
Reference in New Issue
Block a user