Add `jax_debug_log_modules` config option.
This can be used to enable debug logging for specific modules
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or for all of JAX (`JAX_DEBUG_LOG_MODULES="jax"`).
Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-06-07 00:20:32 +00:00
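The sketch below is a rough mental model of the option described above. It applies a comma-separated module list using Python's standard `logging` hierarchy. `apply_debug_log_modules` is a hypothetical helper written only for this illustration. It is not JAX's implementation, which may also install its own handlers and formatters.

```python
import logging


def apply_debug_log_modules(value: str) -> list[logging.Logger]:
  """Hypothetical helper: set DEBUG on each logger named in a comma-separated list.

  A sketch of the idea only; JAX's own handling of `jax_debug_log_modules`
  may differ (for example, in which handlers it installs).
  """
  enabled = []
  for name in (part.strip() for part in value.split(",")):
    if not name:
      continue  # ignore empty entries such as "" or a trailing comma
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    enabled.append(logger)
  return enabled


# Usage: enable two modules; unrelated loggers keep the root default (WARNING).
apply_debug_log_modules("jax._src.xla_bridge,jax._src.dispatch")
print(logging.getLogger("jax._src.dispatch").isEnabledFor(logging.DEBUG))  # True
print(logging.getLogger("jax._src.interpreters.pxla").isEnabledFor(logging.DEBUG))  # False
```

In this model, a logger without its own level inherits from its nearest ancestor that has one. Setting only `jax` to `DEBUG` would therefore also enable `jax._src.dispatch` and `jax._src.xla_bridge`. That is consistent with the `JAX_DEBUG_LOG_MODULES=jax` example output above.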

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import io
import logging
import platform
import re
import shlex
import subprocess
import sys
Introduce hermetic CUDA in Google ML projects.
1) Hermetic CUDA rules allow building wheels with GPU support on a machine without GPUs, as well as running Bazel GPU tests on a machine with only GPUs and the NVIDIA driver installed. When `--config=cuda` is provided in the Bazel options, Bazel downloads the CUDA, CUDNN and NCCL redistributions into its cache and uses them during the build and test phases.
[Default location of CUDNN redistributions](https://developer.download.nvidia.com/compute/cudnn/redist/)
[Default location of CUDA redistributions](https://developer.download.nvidia.com/compute/cuda/redist/)
[Default location of NCCL redistributions](https://pypi.org/project/nvidia-nccl-cu12/#history)
2) To include hermetic CUDA rules in your project, add the following to the WORKSPACE of the downstream project that depends on XLA.
Note: use `@local_tsl` instead of `@tsl` in the TensorFlow project.
```
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",
)
cuda_json_init_repository()
load(
"@cuda_redist_json//:distributions.bzl",
"CUDA_REDISTRIBUTIONS",
"CUDNN_REDISTRIBUTIONS",
)
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
"cuda_redist_init_repositories",
"cudnn_redist_init_repository",
)
cuda_redist_init_repositories(
cuda_redistributions = CUDA_REDISTRIBUTIONS,
)
cudnn_redist_init_repository(
cudnn_redistributions = CUDNN_REDISTRIBUTIONS,
)
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
"cuda_configure",
)
cuda_configure(name = "local_config_cuda")
load(
"@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
"nccl_redist_init_repository",
)
nccl_redist_init_repository()
load(
"@tsl//third_party/nccl/hermetic:nccl_configure.bzl",
"nccl_configure",
)
nccl_configure(name = "local_config_nccl")
```
PiperOrigin-RevId: 662981325
2024-08-14 10:57:53 -07:00
import tempfile
import textwrap
import unittest

import jax
import jax._src.test_util as jtu
from jax._src import xla_bridge
from jax._src.logging_config import _default_TF_CPP_MIN_LOG_LEVEL

# Note: importing absltest causes an extra absl root log handler to be
# registered, which causes extra debug log messages. We don't expect users to
# import absl logging, so it should only affect this test. We need to use
# absltest.main and config.parse_flags_with_absl() in order for jax_test flag
# parsing to work correctly with bazel (otherwise we could avoid importing
# absltest/absl logging altogether).
from absl.testing import absltest

jax.config.parse_flags_with_absl()


@contextlib.contextmanager
def jax_debug_log_modules(value):
  # jax_debug_log_modules doesn't have a context manager, because it's
  # not thread-safe. But since tests are always single-threaded, we
  # can define one here.
  original_value = jax.config.jax_debug_log_modules
  jax.config.update("jax_debug_log_modules", value)
  try:
    yield
  finally:
    jax.config.update("jax_debug_log_modules", original_value)


@contextlib.contextmanager
def jax_logging_level(value):
  # jax_logging_level doesn't have a context manager, because it's
  # not thread-safe. But since tests are always single-threaded, we
  # can define one here.
  original_value = jax.config.jax_logging_level
  jax.config.update("jax_logging_level", value)
  try:
    yield
  finally:
    jax.config.update("jax_logging_level", original_value)
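

# Hypothetical sketch (not used by the tests below): jax_debug_log_modules()
# and jax_logging_level() differ only in the option name, so a single helper
# could save, override and restore any option. `config` defaults to
# jax.config and is assumed only to support attribute reads and
# update(name, value), the two operations the helpers above rely on.
@contextlib.contextmanager
def _temporary_config_value(name, value, config=None):
  config = jax.config if config is None else config
  original_value = getattr(config, name)
  config.update(name, value)
  try:
    yield
  finally:
    # Restore the saved value even if the body raised.
    config.update(name, original_value)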


@contextlib.contextmanager
def capture_jax_logs():
  log_output = io.StringIO()
  handler = logging.StreamHandler(log_output)
  logger = logging.getLogger("jax")

  logger.addHandler(handler)
  try:
    yield log_output
  finally:
    logger.removeHandler(handler)
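

# Hypothetical sketch (not used by the tests below) of the stdlib behavior that
# capture_jax_logs() relies on: a handler attached to a parent logger receives
# records logged on its children through propagation, but a DEBUG record is
# only created when the child's effective level allows it. The logger names
# "_sketch" and "_sketch.dispatch" are stand-ins, not real JAX modules.
def _sketch_capture_via_propagation(enable_debug):
  log_output = io.StringIO()
  handler = logging.StreamHandler(log_output)
  parent = logging.getLogger("_sketch")
  child = logging.getLogger("_sketch.dispatch")
  original_level = child.level
  parent.addHandler(handler)
  try:
    # NOTSET defers to the ancestors (the root logger defaults to WARNING).
    child.setLevel(logging.DEBUG if enable_debug else logging.NOTSET)
    child.debug("Finished tracing + transforming")
  finally:
    child.setLevel(original_level)
    parent.removeHandler(handler)
  return log_output.getvalue()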


class LoggingTest(jtu.JaxTestCase):

  @unittest.skipIf(platform.system() == "Windows",
                   "Subprocess test doesn't work on Windows")
  def test_no_log_spam(self):
    if jtu.is_cloud_tpu() and xla_bridge._backends:
      raise self.skipTest(
          "test requires fresh process on Cloud TPU because only one process "
          "can use the TPU at a time")
    if sys.executable is None:
      raise self.skipTest("test requires access to python binary")

    # Save script in file to fix the problem with
    # `tsl::Env::Default()->GetExecutablePath()` not working properly with
    # command flag.
    with tempfile.NamedTemporaryFile(
        mode="w+", encoding="utf-8", suffix=".py"
    ) as f:
      f.write(textwrap.dedent("""
        import jax
        jax.device_count()
        f = jax.jit(lambda x: x + 1)
        f(1)
        f(2)
        jax.numpy.add(1, 1)
      """))
      # Flush so the child process reads the full script, not an empty file.
      f.flush()
      python = sys.executable
      assert "python" in python
      # Make sure C++ logging is at default level for the test process.
      proc = subprocess.run([python, f.name], capture_output=True)

    lines = proc.stdout.split(b"\n")
    lines.extend(proc.stderr.split(b"\n"))
    allowlist = [
        b"",
        (
            b"An NVIDIA GPU may be present on this machine, but a"
            b" CUDA-enabled jaxlib is not installed. Falling back to cpu."
        ),
    ]
    lines = [l for l in lines if l not in allowlist]
    self.assertEmpty(lines)

  def test_debug_logging(self):
    # Warm up so we don't get the "No GPU/TPU" warning later.
    jax.jit(lambda x: x + 1)(1)

    # Nothing logged by default (except warning messages, which we don't expect
    # here).
    with capture_jax_logs() as log_output:
      jax.jit(lambda x: x + 1)(1)
    self.assertEmpty(log_output.getvalue())

    # Turn on all debug logging.
    with jax_debug_log_modules("jax"):
      with capture_jax_logs() as log_output:
        jax.jit(lambda x: x + 1)(1)
      self.assertIn("Finished tracing + transforming", log_output.getvalue())
      self.assertIn("Compiling <lambda>", log_output.getvalue())

    # Turn off all debug logging.
    with jax_debug_log_modules(""):
      with capture_jax_logs() as log_output:
        jax.jit(lambda x: x + 1)(1)
      self.assertEmpty(log_output.getvalue())

    # Turn on one module.
    with jax_debug_log_modules("jax._src.dispatch"):
      with capture_jax_logs() as log_output:
        jax.jit(lambda x: x + 1)(1)
      self.assertIn("Finished tracing + transforming", log_output.getvalue())
      self.assertNotIn("Compiling <lambda>", log_output.getvalue())

    # Turn everything off again.
    with jax_debug_log_modules(""):
      with capture_jax_logs() as log_output:
        jax.jit(lambda x: x + 1)(1)
      self.assertEmpty(log_output.getvalue())

  @jtu.skip_on_devices("tpu")
  @unittest.skipIf(platform.system() == "Windows",
                   "Subprocess test doesn't work on Windows")
  def test_subprocess_stderr_info_logging(self):
    if sys.executable is None:
      raise self.skipTest("test requires access to python binary")

    program = """
    import jax # this prints INFO logging from backend imports
    jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
    """

    # strip the leading whitespace from the program script
    program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)

    # test INFO
    cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
                      f" '{program}'")
    p = subprocess.run(cmd, capture_output=True, text=True)
    log_output = p.stderr
    info_lines = log_output.split("\n")
    self.assertGreater(len(info_lines), 0)
    self.assertIn("INFO", log_output)
    self.assertNotIn("DEBUG", log_output)
|
|
|
|
|
  @jtu.skip_on_devices("tpu")
  @unittest.skipIf(platform.system() == "Windows",
                   "Subprocess test doesn't work on Windows")
  def test_subprocess_stderr_debug_logging(self):
    if sys.executable is None:
      raise self.skipTest("test requires access to python binary")

    program = """
    import jax  # this prints INFO logging from backend imports
    jax.jit(lambda x: x)(1)  # this prints logs to DEBUG (from compilation)
    """

    # strip the leading whitespace from the program script
    program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)

    # test DEBUG
    cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
                      f" '{program}'")
    p = subprocess.run(cmd, capture_output=True, text=True)
    log_output = p.stderr
    self.assertIn("INFO", log_output)
    self.assertIn("DEBUG", log_output)

    # test JAX_DEBUG_LOG_MODULES
    cmd = shlex.split(f"env JAX_DEBUG_LOG_MODULES=jax {sys.executable} -c"
                      f" '{program}'")
    p = subprocess.run(cmd, capture_output=True, text=True)
    log_output = p.stderr
    self.assertIn("DEBUG", log_output)

  @jtu.skip_on_devices("tpu")
  @unittest.skipIf(platform.system() == "Windows",
                   "Subprocess test doesn't work on Windows")
  def test_subprocess_toggling_logging_level(self):
    if sys.executable is None:
      raise self.skipTest("test requires access to python binary")

    _separator = "---------------------------"
    program = f"""
    import sys
    import jax  # this prints INFO logging from backend imports
    jax.jit(lambda x: x)(1)  # this prints logs to DEBUG (from compilation)
    jax.config.update("jax_logging_level", None)
    sys.stderr.write("{_separator}")
    jax.jit(lambda x: x)(1)  # should not log anything now
    """

    # strip the leading whitespace from the program script
    program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)

    cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
                      f" '{program}'")
    p = subprocess.run(cmd, capture_output=True, text=True)
    log_output = p.stderr
    m = re.search(_separator, log_output)
    self.assertIsNotNone(m)
    log_output_verbose = log_output[:m.start()]
    log_output_silent = log_output[m.end():]

    self.assertIn("Finished tracing + transforming <lambda> for pjit",
                  log_output_verbose)
    self.assertEqual(log_output_silent, "")

  @jtu.skip_on_devices("tpu")
  @unittest.skipIf(platform.system() == "Windows",
                   "Subprocess test doesn't work on Windows")
  def test_subprocess_double_logging_absent(self):
    if sys.executable is None:
      raise self.skipTest("test requires access to python binary")

    program = """
    import jax  # this prints INFO logging from backend imports
    jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch")
    jax.jit(lambda x: x)(1)  # this prints logs to DEBUG (from compilation)
    """

    # strip the leading whitespace from the program script
    program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)

    cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
                      f" '{program}'")
    p = subprocess.run(cmd, capture_output=True, text=True)
    log_output = p.stderr
    self.assertNotEmpty(log_output)
    log_lines = log_output.strip().split("\n")
    # exactly one tracing line should be printed; more than one means
    # log records are being emitted twice
    self.assertLen([line for line in log_lines
                    if "Finished tracing + transforming" in line], 1)

  @jtu.skip_on_devices("tpu")
  @unittest.skipIf(platform.system() == "Windows",
                   "Subprocess test doesn't work on Windows")
  def test_subprocess_cpp_logging_level(self):
    if sys.executable is None:
      raise self.skipTest("test requires access to python binary")

    program = """
    import sys
    import jax  # this prints INFO logging from backend imports
    jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0)
    """

    # strip the leading whitespace from the program script
    program = re.sub(r"^\s+", "", program, flags=re.MULTILINE)

    # verbose logging: DEBUG, INFO
    cmd = shlex.split(f"env JAX_LOGGING_LEVEL=DEBUG {sys.executable} -c"
                      f" '{program}'")
    p = subprocess.run(cmd, capture_output=True, text=True)
    self.assertIn("Initializing CoordinationService", p.stderr)

    cmd = shlex.split(f"env JAX_LOGGING_LEVEL=INFO {sys.executable} -c"
                      f" '{program}'")
    p = subprocess.run(cmd, capture_output=True, text=True)
    self.assertIn("Initializing CoordinationService", p.stderr)

    # non-verbose logging: WARNING, None
    cmd = shlex.split(f"env JAX_LOGGING_LEVEL=WARNING {sys.executable} -c"
                      f" '{program}'")
    p = subprocess.run(cmd, capture_output=True, text=True)
    self.assertNotIn("Initializing CoordinationService", p.stderr)

    cmd = shlex.split(f"{sys.executable} -c"
                      f" '{program}'")
    p = subprocess.run(cmd, capture_output=True, text=True)
    if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1:
      self.assertNotIn("Initializing CoordinationService", p.stderr)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())