Use epath from etils package. This CL also makes epath a required dep for JAX.

This is being used in the following ways in this CL:

* To dump IR, you can now pass paths with `gs://` or `cns` and the HLO can be dumped to those paths.
* Removing the TF dep from gda serialization.

PiperOrigin-RevId: 452117007
This commit is contained in:
Yash Katariya 2022-05-31 12:46:54 -07:00 committed by jax authors
parent f2add8783a
commit e0ff842c2a
3 changed files with 6 additions and 5 deletions

View File

@ -53,6 +53,7 @@ from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
import jax._src.util as util
from jax._src.util import flatten, unflatten
from etils import epath
FLAGS = flags.FLAGS
@ -810,9 +811,8 @@ def _make_string_safe_for_filename(s: str) -> str:
def _dump_ir_to_file(name: str, ir: str):
id = next(_ir_dump_counter)
name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir"
name = os.path.join(FLAGS.jax_dump_ir_to, name)
with open(name, "w") as f:
f.write(ir)
name = epath.Path(FLAGS.jax_dump_ir_to) / name
name.write_text(ir)
def compile_or_get_cached(backend, computation, compile_options):

View File

@ -28,7 +28,7 @@ from jax.experimental.maps import Mesh
import jax.numpy as jnp
import numpy as np
import tensorstore as ts
import tensorflow.compat.v2 as tf
from etils import epath
TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}})
@ -256,7 +256,7 @@ class GlobalAsyncCheckpointManager:
if current_process == 0:
logging.info('Renaming %s to %s', temp_checkpoint_dir, final_checkpoint_dir)
tf.io.gfile.rename(temp_checkpoint_dir, final_checkpoint_dir)
epath.Path(temp_checkpoint_dir).rename(final_checkpoint_dir)
logging.info('Finished saving GDA checkpoint to `%s`.', final_checkpoint_dir)
self._client.key_value_set(_get_key(self._final_ckpt_dir), _CHECKPOINT_SUCCESS)
except Exception as e:

View File

@ -44,6 +44,7 @@ setup(
'opt_einsum',
'scipy>=1.2.1',
'typing_extensions',
'etils[epath]'
],
extras_require={
# Minimum jaxlib version; used in testing.