mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix tests that were adding __annotations__ for some reason to CompileOptions.
PiperOrigin-RevId: 515133225
This commit is contained in:
parent
9c4db8c962
commit
942e79ffe3
@ -128,11 +128,17 @@ def _hash_computation(hash_obj, xla_computation):
|
||||
hash_obj.update(scrubbed_hlo)
|
||||
|
||||
def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
expected_num_compile_options = 37
|
||||
assert len(dir(compile_options_obj)) == expected_num_compile_options, (
|
||||
f"Unexpected number of CompileOption fields: "
|
||||
f"{len(dir(compile_options_obj))}. This likely: means that an extra "
|
||||
f"field was added, and this function needs to be updated.")
|
||||
expected_num_compile_options = 11
|
||||
# Ignore private and built-in methods. These can unexpectedly change and lead
|
||||
# to false positives, e.g. when different Python versions include different
|
||||
# built-ins.
|
||||
num_actual_options = len(
|
||||
[x for x in dir(compile_options_obj) if not x.startswith("_")])
|
||||
assert num_actual_options == expected_num_compile_options, (
|
||||
"Unexpected number of CompileOption fields: "
|
||||
f"{num_actual_options}. This likely: means that an extra "
|
||||
"field was added, and this function needs to be updated."
|
||||
)
|
||||
|
||||
if compile_options_obj.argument_layouts is not None:
|
||||
map(lambda shape: hash_obj.update(shape.to_serialized_proto()),
|
||||
|
Loading…
x
Reference in New Issue
Block a user