mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Enable all native lowering jax2tf tests
Filed bugs for the few remaining tests, and disabled them. Fixed the logging of the compiled HLO on test failure. PiperOrigin-RevId: 510135651
This commit is contained in:
parent
454e4de524
commit
a9e886f956
@ -127,6 +127,27 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
"top_k_sort_inf_nan_inshape=float32[5]_k=5" in harness.fullname):
|
||||
raise unittest.SkipTest("Unexplained failure, but in old no_jax_array")
|
||||
|
||||
if (config.jax2tf_default_experimental_native_lowering and
|
||||
"cholesky" in harness.fullname):
|
||||
raise unittest.SkipTest("b/269386856: cholesky failures")
|
||||
|
||||
if (config.jax2tf_default_experimental_native_lowering and
|
||||
device in ["cpu", "gpu"] and
|
||||
"igammac" in harness.fullname):
|
||||
raise unittest.SkipTest("b/269401509: igammac failures")
|
||||
|
||||
if (config.jax2tf_default_experimental_native_lowering and
|
||||
device == "tpu" and
|
||||
harness.dtype == jnp.bfloat16 and
|
||||
"eigh" in harness.fullname):
|
||||
raise unittest.SkipTest("b/269388842: eigh failures on TPU for bfloat16")
|
||||
|
||||
if (config.jax2tf_default_experimental_native_lowering and
|
||||
device == "gpu" and
|
||||
"lu" in harness.fullname):
|
||||
raise unittest.SkipTest("b/269388847: lu failures on GPU")
|
||||
|
||||
|
||||
associative_scan_reductions = harness.params.get("associative_scan_reductions", False)
|
||||
try:
|
||||
with jax.jax2tf_associative_scan_reductions(associative_scan_reductions):
|
||||
|
@ -1044,7 +1044,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
|
||||
polymorphic_shapes=["b1", "b2"],
|
||||
expect_error=(
|
||||
(None, None) if not config.jax2tf_default_experimental_native_lowering else
|
||||
(None, None) if (not config.jax2tf_default_experimental_native_lowering or
|
||||
not config.jax_jit_pjit_api_merge) else
|
||||
(ValueError,
|
||||
"The following dimension variables cannot be computed from the static shapes of the kept lowered arguments")))
|
||||
|
||||
@ -1056,7 +1057,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
|
||||
polymorphic_shapes=["b1", "b1 * b1"],
|
||||
expect_error=(
|
||||
(None, None) if not config.jax2tf_default_experimental_native_lowering else
|
||||
(None, None) if (not config.jax2tf_default_experimental_native_lowering or
|
||||
not config.jax_jit_pjit_api_merge) else
|
||||
(ValueError,
|
||||
"The following dimension variables cannot be computed from the static shapes of the kept lowered arguments")))
|
||||
|
||||
@ -1067,7 +1069,8 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
|
||||
input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
|
||||
polymorphic_shapes=["b1", "b2"],
|
||||
expect_error=(
|
||||
(None, None) if not config.jax2tf_default_experimental_native_lowering else
|
||||
(None, None) if (not config.jax2tf_default_experimental_native_lowering or
|
||||
not config.jax_jit_pjit_api_merge) else
|
||||
(KeyError,
|
||||
"Encountered dimension variable 'b1' that is not appearing in the shapes")))
|
||||
|
||||
|
@ -291,10 +291,11 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
|
||||
logging.info("[%s] Logging HLO for exception in mode %s: %s",
|
||||
self._testMethodName, mode, e)
|
||||
jax_comp = jax.xla_computation(func_jax)(*args)
|
||||
jax_hlo = jax_comp.as_hlo_text()
|
||||
jax_lowered = jax.jit(func_jax).lower(*args)
|
||||
# We log the HLO dialect for easier comparison with TF
|
||||
logging.info("[%s] JAX NON_OPT HLO\n%s",
|
||||
self._testMethodName, jax_hlo)
|
||||
self._testMethodName,
|
||||
jax_lowered.compiler_ir(dialect="hlo").as_hlo_text()) # type: ignore
|
||||
|
||||
tf_args_signature = _make_tf_input_signature(*args)
|
||||
# If we give the signature, we cannot pass scalars
|
||||
@ -313,7 +314,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
tf_hlo)
|
||||
|
||||
backend = xla_bridge.get_backend()
|
||||
modules = backend.compile(jax_comp).hlo_modules()
|
||||
modules = backend.compile(str(jax_lowered.compiler_ir())).hlo_modules()
|
||||
jax_opt_hlo = modules[0].to_string()
|
||||
logging.info("[%s] JAX OPT HLO\n%s", self._testMethodName,
|
||||
jax_opt_hlo)
|
||||
|
Loading…
x
Reference in New Issue
Block a user