mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Create a JAX config API to allow custom PJRT client create-option settings. This allows a device platform's PJRT client to be aware of the calling (customer) ML framework.
PiperOrigin-RevId: 638009713
This commit is contained in:
parent
db11842387
commit
91d68b5564
@ -389,6 +389,7 @@ pytype_strict_library(
|
||||
name = "cloud_tpu_init",
|
||||
srcs = ["_src/cloud_tpu_init.py"],
|
||||
deps = [
|
||||
":config",
|
||||
":hardware_utils",
|
||||
":version",
|
||||
],
|
||||
|
@ -13,8 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from jax._src import hardware_utils
|
||||
from jax import version
|
||||
from jax._src import config
|
||||
from jax._src import hardware_utils
|
||||
|
||||
running_in_cloud_tpu_vm: bool = False
|
||||
|
||||
@ -73,3 +74,9 @@ def cloud_tpu_init() -> None:
|
||||
# this makes tensorstore serialization work better on TPU
|
||||
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60')
|
||||
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES', '256')
|
||||
|
||||
if config.jax_pjrt_client_create_options.value is None:
|
||||
config.update(
|
||||
'jax_pjrt_client_create_options',
|
||||
f'ml_framework_name:JAX;ml_framework_version:{version.__version__}'
|
||||
)
|
||||
|
@ -935,6 +935,12 @@ jax_platforms = define_optional_string_state(
|
||||
'otherwise.'
|
||||
))
|
||||
|
||||
jax_pjrt_client_create_options = define_optional_string_state(
|
||||
name='jax_pjrt_client_create_options',
|
||||
default=None,
|
||||
help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings '
|
||||
'provided to a device platform pjrt client as extra arguments.'))
|
||||
|
||||
enable_checks = define_bool_state(
|
||||
name='jax_enable_checks',
|
||||
default=False,
|
||||
|
@ -47,6 +47,7 @@ from jax._src.cloud_tpu_init import maybe_import_libtpu
|
||||
from jax._src.lib import cuda_versions
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import jaxlib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -160,7 +161,13 @@ def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None:
|
||||
t.start()
|
||||
|
||||
try:
|
||||
client = xla_client.make_tpu_client(_get_tpu_library_path())
|
||||
if xla_extension_version >= 267:
|
||||
client = xla_client.make_tpu_client( # type: ignore
|
||||
_get_tpu_library_path(),
|
||||
_options_from_jax_configs("tpu"))
|
||||
else:
|
||||
client = xla_client.make_tpu_client(
|
||||
_get_tpu_library_path())
|
||||
finally:
|
||||
t.cancel()
|
||||
|
||||
@ -618,16 +625,30 @@ def discover_pjrt_plugins() -> None:
|
||||
|
||||
|
||||
def _options_from_jax_configs(plugin_name):
  """Builds the PJRT client creation options for the plugin `plugin_name`.

  The options come from two places:

  * The `jax_pjrt_client_create_options` config value, a string of the form
    "k1:v1;k2:v2". Each `key:value` pair is copied into the result as string
    key and string value. These options are returned for every plugin.
  * For the "cuda" plugin only, the CUDA-specific settings read from
    `CUDA_VISIBLE_DEVICES`, `_USE_MOCK_GPU_CLIENT` and `_MOCK_NUM_GPUS`.

  Args:
    plugin_name: name of the PJRT plugin the options are built for.

  Returns:
    A dict mapping option names to option values.

  Raises:
    RuntimeError: if a ';'-separated entry of `jax_pjrt_client_create_options`
      does not contain exactly one ':' separator.
  """
  options = {}

  # Generic, user-configurable options. These are intentionally parsed before
  # any plugin check so that non-CUDA plugins also receive them.
  pjrt_client_options = config.jax_pjrt_client_create_options.value
  if pjrt_client_options:
    for option in pjrt_client_options.split(";"):
      key, separator, value = option.partition(":")
      # Exactly one ':' is allowed: it must be present, and must not appear
      # again in the value part.
      if not separator or ":" in value:
        raise RuntimeError(
            "Invalid option in jax_pjrt_client_create_options: "
            f"'{option}'. Expected exactly one ':' separator, "
            "in the format 'key:value'")
      options[key] = value

  if plugin_name == "cuda":
    visible_devices = CUDA_VISIBLE_DEVICES.value
    # 'all' means no restriction, so the option is simply left unset.
    if visible_devices != 'all':
      options['visible_devices'] = [int(x) for x in visible_devices.split(',')]
    options['enable_mock_nccl'] = _USE_MOCK_GPU_CLIENT.value
    # The node count is only supplied when the mock client is enabled.
    if options['enable_mock_nccl']:
      options['num_nodes'] = _MOCK_NUM_GPUS.value

  return options
|
||||
|
||||
|
||||
|
@ -26,6 +26,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -143,7 +144,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
def _mock_tpu_client(library_path=None):
|
||||
def _mock_tpu_client_with_options(library_path=None, options=None):
|
||||
time_to_wait = 5
|
||||
start = time.time()
|
||||
while not w:
|
||||
@ -157,6 +158,14 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
msg = str(w[-1].message)
|
||||
self.assertIn("Did you run your code on all TPU hosts?", msg)
|
||||
|
||||
def _mock_tpu_client(library_path=None):
|
||||
_mock_tpu_client_with_options(library_path=library_path, options=None)
|
||||
|
||||
if xla_extension_version >= 267:
|
||||
with mock.patch.object(xc, "make_tpu_client",
|
||||
side_effect=_mock_tpu_client_with_options):
|
||||
xb.tpu_client_timer_callback(0.01)
|
||||
else:
|
||||
with mock.patch.object(xc, "make_tpu_client",
|
||||
side_effect=_mock_tpu_client):
|
||||
xb.tpu_client_timer_callback(0.01)
|
||||
|
Loading…
x
Reference in New Issue
Block a user