Update JAX flag parsing to work when some elements of argv aren't absl friendly.

PiperOrigin-RevId: 389168058
This commit is contained in:
Tom Hennigan 2021-08-06 07:05:43 -07:00 committed by jax authors
parent aee61dab8a
commit c94f41290c

View File

@ -17,6 +17,7 @@
import contextlib
import functools
import itertools
import os
import sys
import threading
@ -147,9 +148,15 @@ class Config:
def parse_flags_with_absl(self):
global already_configured_with_absl
if not already_configured_with_absl:
# Extract just the --jax... flags (before the first --) from argv. In some
# environments (e.g. ipython/colab) argv might be a mess of things
# parseable by absl and other junk.
jax_argv = itertools.takewhile(lambda a: a != '--', sys.argv)
jax_argv = ['', *(a for a in jax_argv if a.startswith('--jax'))]
import absl.flags
self.config_with_absl()
absl.flags.FLAGS(sys.argv, known_only=True)
absl.flags.FLAGS(jax_argv, known_only=True)
self.complete_absl_config(absl.flags)
already_configured_with_absl = True