Add support for int fields to compiler_options.

PiperOrigin-RevId: 549790380
This commit is contained in:
Parker Schuh 2023-07-20 17:35:57 -07:00 committed by jax authors
parent 60bb3bc5c5
commit dcad04d244
2 changed files with 6 additions and 1 deletions

View File

@ -242,6 +242,8 @@ def _hash_compile_options(hash_obj, compile_options_obj):
_hash_string(hash_obj, kv[1])
elif isinstance(kv[1], bool):
_hash_bool(hash_obj, kv[1])
elif isinstance(kv[1], int):
_hash_int(hash_obj, kv[1])
else:
raise RuntimeError("Invalid type: %s" % repr(type(kv[1])))

View File

@ -1254,6 +1254,8 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
@unittest.skipIf(xla_extension_version < 171,
'Test requires xla_extension_version >= 171')
def test_jit_lower_compile_with_compiler_options(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.
@ -1261,7 +1263,8 @@ class CPPJitTest(jtu.BufferDonationTestCase):
f_jit = self.jit(f)
lowered = f_jit.lower(1.)
lowered.compile( # doesn't crash
compiler_options={"xla_embed_ir_in_executable": True})
compiler_options={"xla_embed_ir_in_executable": True,
"xla_dump_max_hlo_modules": 200})
def test_jit_lower_compile_with_compiler_options_invalid(self):
def f(x):