[jax2tf] Simplifies model testing file writing logic.

PiperOrigin-RevId: 500984975
This commit is contained in:
Marc van Zee 2023-01-10 07:07:16 -08:00 committed by jax authors
parent 62f2b9680b
commit 10847a9372

View File

@ -111,21 +111,21 @@ def _write_markdown(results: Dict[str, List[Tuple[str, str,]]]) -> None:
error_lines.append(f"## `{example_name}`")
error_lines.extend(example_error_lines)
# TODO(marcvanzee): This is somewhat brittle, consider rewriting it.
g3doc_path = "../g3doc"
output_path = os.path.join(g3doc_path, "convert_models_results.md")
template_path = output_path + ".template"
template = "".join(open(template_path).readlines())
template = template.replace("{{generation_date}}", str(datetime.date.today()))
template = template.replace("{{table}}", "\n".join(table_lines))
template = template.replace("{{errors}}", "\n".join(error_lines))
if (workdir := "BUILD_WORKING_DIRECTORY") in os.environ:
os.chdir(os.path.dirname(os.environ[workdir]))
with tf.io.gfile.GFile(output_path, "w") as f:
f.write(template)
with tf.io.gfile.GFile(template_path, "r") as f_in, \
tf.io.gfile.GFile(output_path, "w") as f_out:
template = "".join(f_in.readlines())
template = template.replace("{{generation_date}}", str(datetime.date.today()))
template = template.replace("{{table}}", "\n".join(table_lines))
template = template.replace("{{errors}}", "\n".join(error_lines))
f_out.write(template)
print("Written converter results to", output_path)