mirror of
https://github.com/ROCm/jax.git
synced 2025-04-13 02:16:06 +00:00
Warn the user if transparent huge pages aren't enabled.
PiperOrigin-RevId: 735431881
This commit is contained in:
parent
14b215fe76
commit
5cb29949d4
@ -15,6 +15,7 @@
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from jax import version
|
||||
from jax._src import config
|
||||
from jax._src import hardware_utils
|
||||
@ -72,7 +73,19 @@ def cloud_tpu_init() -> None:
|
||||
|
||||
# Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed.
|
||||
libtpu_path = get_tpu_library_path()
|
||||
num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0]
|
||||
num_tpu_chips, tpu_id = hardware_utils.num_available_tpu_chips_and_device_id()
|
||||
if (
|
||||
tpu_id is not None
|
||||
and tpu_id >= hardware_utils.TpuVersion.v5e
|
||||
and not hardware_utils.transparent_hugepages_enabled()
|
||||
):
|
||||
warnings.warn(
|
||||
'Transparent hugepages are not enabled. TPU runtime startup and'
|
||||
' shutdown time should be significantly improved on TPU v5e and newer.'
|
||||
' If not already set, you may need to enable transparent hugepages in'
|
||||
' your VM image (sudo sh -c "echo always >'
|
||||
' /sys/kernel/mm/transparent_hugepage/enabled")'
|
||||
)
|
||||
if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init():
|
||||
return
|
||||
|
||||
|
@ -12,25 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
import os
|
||||
import pathlib
|
||||
import glob
|
||||
|
||||
_GOOGLE_PCI_VENDOR_ID = '0x1ae0'
|
||||
_TPU_PCI_DEVICE_IDS = [
|
||||
# TPU v2, v3
|
||||
'0x0027',
|
||||
# No public name (plc)
|
||||
'0x0056',
|
||||
# TPU v4
|
||||
'0x005e',
|
||||
# TPU v5p
|
||||
'0x0062',
|
||||
# TPU v5e
|
||||
'0x0063',
|
||||
# TPU v6e
|
||||
'0x006f',
|
||||
]
|
||||
|
||||
_NVIDIA_GPU_DEVICES = [
|
||||
'/dev/nvidia0',
|
||||
@ -38,10 +25,36 @@ _NVIDIA_GPU_DEVICES = [
|
||||
'/dev/dxg', # WSL2
|
||||
]
|
||||
|
||||
|
||||
class TpuVersion(enum.IntEnum):
|
||||
# TPU v2, v3
|
||||
v2 = 0
|
||||
v3 = 1
|
||||
# No public name (plc)
|
||||
plc = 2
|
||||
# TPU v4
|
||||
v4 = 3
|
||||
# TPU v5p
|
||||
v5p = 4
|
||||
# TPU v5e
|
||||
v5e = 5
|
||||
# TPU v6e
|
||||
v6e = 6
|
||||
|
||||
|
||||
_TPU_PCI_DEVICE_IDS = {
|
||||
'0x0027': TpuVersion.v3,
|
||||
'0x0056': TpuVersion.plc,
|
||||
'0x005e': TpuVersion.v4,
|
||||
'0x0062': TpuVersion.v5p,
|
||||
'0x0063': TpuVersion.v5e,
|
||||
'0x006f': TpuVersion.v6e,
|
||||
}
|
||||
|
||||
def num_available_tpu_chips_and_device_id():
|
||||
"""Returns the device id and number of TPU chips attached through PCI."""
|
||||
num_chips = 0
|
||||
device_id = ''
|
||||
tpu_version = None
|
||||
for vendor_path in glob.glob('/sys/bus/pci/devices/*/vendor'):
|
||||
vendor_id = pathlib.Path(vendor_path).read_text().strip()
|
||||
if vendor_id != _GOOGLE_PCI_VENDOR_ID:
|
||||
@ -50,12 +63,20 @@ def num_available_tpu_chips_and_device_id():
|
||||
device_path = os.path.join(os.path.dirname(vendor_path), 'device')
|
||||
device_id = pathlib.Path(device_path).read_text().strip()
|
||||
if device_id in _TPU_PCI_DEVICE_IDS:
|
||||
tpu_version = _TPU_PCI_DEVICE_IDS[device_id]
|
||||
num_chips += 1
|
||||
|
||||
return num_chips, device_id
|
||||
return num_chips, tpu_version
|
||||
|
||||
|
||||
def has_visible_nvidia_gpu() -> bool:
|
||||
"""True if there's a visible nvidia gpu available on device, False otherwise."""
|
||||
|
||||
return any(os.path.exists(d) for d in _NVIDIA_GPU_DEVICES)
|
||||
|
||||
|
||||
def transparent_hugepages_enabled() -> bool:
|
||||
# See https://docs.kernel.org/admin-guide/mm/transhuge.html for more
|
||||
# information about transparent huge pages.
|
||||
path = pathlib.Path('/sys/kernel/mm/transparent_hugepage/enabled')
|
||||
return path.exists() and path.read_text().strip() == '[always] madvise never'
|
||||
|
@ -66,6 +66,9 @@ filterwarnings = [
|
||||
# https://github.com/protocolbuffers/protobuf/issues/12186#issuecomment-1745679358
|
||||
"ignore:Type google\\._upb\\._message\\.(Scalar|Message)MapContainer uses PyType_Spec with a metaclass that has custom tp_new\\. This is deprecated and will no longer be allowed in Python 3\\.14\\.:DeprecationWarning",
|
||||
|
||||
# TODO(b/401588349): Remove this once transparent hugepages are enabled.
|
||||
"ignore:Transparent hugepages",
|
||||
|
||||
# NOTE: this is probably not where you want to add code to suppress a
|
||||
# warning. Only pytest tests look at this list, whereas Bazel tests also
|
||||
# check for warnings and do not check this list. Most likely, you should
|
||||
|
Loading…
x
Reference in New Issue
Block a user