Make the regex checking of export_tests less strict

PiperOrigin-RevId: 582704122
This commit is contained in:
Yash Katariya 2023-11-15 09:24:07 -08:00 committed by jax authors
parent 5c3da219c0
commit 118d85cd6c

View File

@ -705,12 +705,12 @@ class JaxExportTest(jtu.JaxTestCase):
expected_re = re.compile(
# The top-level input it replicated
r"func.func .* @main\(%arg0: tensor<16x4xf32> {mhlo.sharding = \"{replicated}\"}\).*"
r"func.func .* @main\(%arg0: tensor<16x4xf32>.*mhlo.sharding = \"{replicated}\"}\).*"
# We apply the in_shardings for f_jax
r".*custom_call @Sharding\(%arg0\) {mhlo.sharding = \"{devices=\[2,1\]<=\[2\]}\"}.*"
r".*custom_call @Sharding\(%arg0\).*mhlo.sharding = \"{devices=\[2,1\]<=\[2\]}\"}.*"
r"%1 = .*call @call_exported_f_jax.*"
# We apply the out_shardings for f_jax
r".*custom_call @Sharding\(%1\) {mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*",
r".*custom_call @Sharding\(%1\).*mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*",
re.DOTALL)
hlo = jax.jit(export.call_exported(exp)).lower(a_device).as_text()
self.assertRegex(hlo, expected_re)
@ -771,8 +771,8 @@ class JaxExportTest(jtu.JaxTestCase):
primal_out_sharding = "{replicated}"
main = re.search(
r"func.func public @main\(%arg0: tensor<10x20xf32> {mhlo.sharding = \"([^\"]+)\""
r".*%arg1: tensor<20x10xf32> {mhlo.sharding = \"([^\"]+)\""
r"func.func public @main\(%arg0: tensor<10x20xf32>.*mhlo.sharding = \"([^\"]+)\""
r".*%arg1: tensor<20x10xf32>.*mhlo.sharding = \"([^\"]+)\""
# result
r".* -> \(tensor<10x20xf32>.*mhlo.sharding = \"([^\"]+)\"",
vjp_module_str)