mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add support for int fields to compiler_options.
PiperOrigin-RevId: 549790380
This commit is contained in:
parent
60bb3bc5c5
commit
dcad04d244
@ -242,6 +242,8 @@ def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
_hash_string(hash_obj, kv[1])
|
||||
elif isinstance(kv[1], bool):
|
||||
_hash_bool(hash_obj, kv[1])
|
||||
elif isinstance(kv[1], int):
|
||||
_hash_int(hash_obj, kv[1])
|
||||
else:
|
||||
raise RuntimeError("Invalid type: %s" % repr(type(kv[1])))
|
||||
|
||||
|
@ -1254,6 +1254,8 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
|
||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 171,
|
||||
'Test requires xla_extension_version >= 171')
|
||||
def test_jit_lower_compile_with_compiler_options(self):
|
||||
def f(x):
|
||||
return jnp.sqrt(x ** 2) + 1.
|
||||
@ -1261,7 +1263,8 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
f_jit = self.jit(f)
|
||||
lowered = f_jit.lower(1.)
|
||||
lowered.compile( # doesn't crash
|
||||
compiler_options={"xla_embed_ir_in_executable": True})
|
||||
compiler_options={"xla_embed_ir_in_executable": True,
|
||||
"xla_dump_max_hlo_modules": 200})
|
||||
|
||||
def test_jit_lower_compile_with_compiler_options_invalid(self):
|
||||
def f(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user