mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Update JAX flag parsing to work when some elements of argv aren't absl friendly.
PiperOrigin-RevId: 389168058
This commit is contained in:
parent
aee61dab8a
commit
c94f41290c
@ -17,6 +17,7 @@
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import itertools
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
@ -147,9 +148,15 @@ class Config:
|
||||
def parse_flags_with_absl(self):
|
||||
global already_configured_with_absl
|
||||
if not already_configured_with_absl:
|
||||
# Extract just the --jax... flags (before the first --) from argv. In some
|
||||
# environments (e.g. ipython/colab) argv might be a mess of things
|
||||
# parseable by absl and other junk.
|
||||
jax_argv = itertools.takewhile(lambda a: a != '--', sys.argv)
|
||||
jax_argv = ['', *(a for a in jax_argv if a.startswith('--jax'))]
|
||||
|
||||
import absl.flags
|
||||
self.config_with_absl()
|
||||
absl.flags.FLAGS(sys.argv, known_only=True)
|
||||
absl.flags.FLAGS(jax_argv, known_only=True)
|
||||
self.complete_absl_config(absl.flags)
|
||||
already_configured_with_absl = True
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user