creat jax config api to allow custom pjrt client create option settings. this allows a device platform's pjrt client be aware of the calling (customer) ml framework

PiperOrigin-RevId: 638009713
This commit is contained in:
Yazhou Zu 2024-05-28 13:42:18 -07:00 committed by jax authors
parent db11842387
commit 91d68b5564
5 changed files with 59 additions and 15 deletions

View File

@ -389,6 +389,7 @@ pytype_strict_library(
name = "cloud_tpu_init",
srcs = ["_src/cloud_tpu_init.py"],
deps = [
":config",
":hardware_utils",
":version",
],

View File

@ -13,8 +13,9 @@
# limitations under the License.
import os
from jax._src import hardware_utils
from jax import version
from jax._src import config
from jax._src import hardware_utils
running_in_cloud_tpu_vm: bool = False
@ -73,3 +74,9 @@ def cloud_tpu_init() -> None:
# this makes tensorstore serialization work better on TPU
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60')
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES', '256')
if config.jax_pjrt_client_create_options.value is None:
config.update(
'jax_pjrt_client_create_options',
f'ml_framework_name:JAX;ml_framework_version:{version.__version__}'
)

View File

@ -935,6 +935,12 @@ jax_platforms = define_optional_string_state(
'otherwise.'
))
jax_pjrt_client_create_options = define_optional_string_state(
name='jax_pjrt_client_create_options',
default=None,
help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings '
'provided to a device platform pjrt client as extra arguments.'))
enable_checks = define_bool_state(
name='jax_enable_checks',
default=False,

View File

@ -47,6 +47,7 @@ from jax._src.cloud_tpu_init import maybe_import_libtpu
from jax._src.lib import cuda_versions
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.lib import jaxlib
logger = logging.getLogger(__name__)
@ -160,7 +161,13 @@ def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None:
t.start()
try:
client = xla_client.make_tpu_client(_get_tpu_library_path())
if xla_extension_version >= 267:
client = xla_client.make_tpu_client( # type: ignore
_get_tpu_library_path(),
_options_from_jax_configs("tpu"))
else:
client = xla_client.make_tpu_client(
_get_tpu_library_path())
finally:
t.cancel()
@ -618,16 +625,30 @@ def discover_pjrt_plugins() -> None:
def _options_from_jax_configs(plugin_name):
if plugin_name != "cuda":
return {}
options = {}
pjrt_client_options = config.jax_pjrt_client_create_options.value
pjrt_client_option_list = []
if pjrt_client_options:
pjrt_client_option_list = pjrt_client_options.split(";")
for option in pjrt_client_option_list:
option_list = option.split(":")
if (len(option_list) != 2):
raise RuntimeError(
"Multiple ':' separators for option in "
f"jax_pjrt_client_create_options: '{option}'. "
"Should be in format 'key:value'")
options[option_list[0]] = option_list[1]
if plugin_name == "cuda":
visible_devices = CUDA_VISIBLE_DEVICES.value
if visible_devices != 'all':
options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
options['enable_mock_nccl'] = _USE_MOCK_GPU_CLIENT.value
if options['enable_mock_nccl']:
options['num_nodes'] = _MOCK_NUM_GPUS.value
return options

View File

@ -26,6 +26,7 @@ from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
config.parse_flags_with_absl()
@ -143,7 +144,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
def _mock_tpu_client(library_path=None):
def _mock_tpu_client_with_options(library_path=None, options=None):
time_to_wait = 5
start = time.time()
while not w:
@ -157,6 +158,14 @@ class XlaBridgeTest(jtu.JaxTestCase):
msg = str(w[-1].message)
self.assertIn("Did you run your code on all TPU hosts?", msg)
def _mock_tpu_client(library_path=None):
_mock_tpu_client_with_options(library_path=library_path, options=None)
if xla_extension_version >= 267:
with mock.patch.object(xc, "make_tpu_client",
side_effect=_mock_tpu_client_with_options):
xb.tpu_client_timer_callback(0.01)
else:
with mock.patch.object(xc, "make_tpu_client",
side_effect=_mock_tpu_client):
xb.tpu_client_timer_callback(0.01)