mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Add a CPU feature guard module to JAX.
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module. Fixes https://github.com/google/jax/issues/6671 PiperOrigin-RevId: 374683190
This commit is contained in:
parent
ae35a09545
commit
d481013f47
@ -31,6 +31,7 @@ py_binary(
|
||||
"//jaxlib",
|
||||
"//jaxlib:setup.py",
|
||||
"//jaxlib:setup.cfg",
|
||||
"//jaxlib:cpu_feature_guard.so",
|
||||
"//jaxlib:lapack.so",
|
||||
"//jaxlib:_pocketfft.so",
|
||||
"//jaxlib:pocketfft_flatbuffers_py",
|
||||
|
@ -154,9 +154,9 @@ def verify_mac_libraries_dont_reference_chkstack():
|
||||
if nm.returncode != 0:
|
||||
raise RuntimeError(f"nm process failed: {nm.stdout} {nm.stderr}")
|
||||
if "____chkstk_darwin" in nm.stdout:
|
||||
raise RuntimeError(
|
||||
"Mac wheel incorrectly depends on symbol ____chkstk_darwin, which "
|
||||
"means that it isn't compatible with older MacOS versions.")
|
||||
raise RuntimeError(
|
||||
"Mac wheel incorrectly depends on symbol ____chkstk_darwin, which "
|
||||
"means that it isn't compatible with older MacOS versions.")
|
||||
|
||||
|
||||
def prepare_wheel(sources_path):
|
||||
@ -172,6 +172,7 @@ def prepare_wheel(sources_path):
|
||||
copy_file(r.Rlocation("__main__/jaxlib/setup.cfg"), dst_dir=sources_path)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/init.py"),
|
||||
dst_filename="__init__.py")
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cpu_feature_guard.so"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/lapack.so"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_pocketfft.so"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/pocketfft_flatbuffers_py_generated.py"))
|
||||
|
@ -46,6 +46,7 @@ pytype_library(
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//third_party/py/jax/jaxlib:_pocketfft",
|
||||
"//third_party/py/jax/jaxlib:cpu_feature_guard",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
|
||||
],
|
||||
)
|
||||
|
@ -58,6 +58,10 @@ def _check_jaxlib_version():
|
||||
|
||||
_check_jaxlib_version()
|
||||
|
||||
if version >= (0, 1, 68):
|
||||
from jaxlib import cpu_feature_guard
|
||||
cpu_feature_guard.check_cpu_features()
|
||||
|
||||
from jaxlib import xla_client
|
||||
from jaxlib import lapack
|
||||
from jaxlib import pocketfft
|
||||
|
@ -346,3 +346,12 @@ pybind_extension(
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "cpu_feature_guard",
|
||||
srcs = ["cpu_feature_guard.c"],
|
||||
module_name = "cpu_feature_guard",
|
||||
deps = [
|
||||
"@org_tensorflow//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
|
160
jaxlib/cpu_feature_guard.c
Normal file
160
jaxlib/cpu_feature_guard.c
Normal file
@ -0,0 +1,160 @@
|
||||
/* Copyright 2021 Google LLC
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || \
|
||||
defined(_M_X64)
|
||||
#define PLATFORM_IS_X86
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define PLATFORM_WINDOWS
|
||||
#endif
|
||||
|
||||
// SIMD extension querying is only available on x86.
|
||||
#ifdef PLATFORM_IS_X86
|
||||
|
||||
#ifdef PLATFORM_WINDOWS
|
||||
#if defined(_MSC_VER)
|
||||
#include <intrin.h>
|
||||
#endif
|
||||
|
||||
// Visual Studio defines a builtin function for CPUID, so use that if possible.
|
||||
#define GETCPUID(a, b, c, d, a_inp, c_inp) \
|
||||
{ \
|
||||
int cpu_info[4] = {-1}; \
|
||||
__cpuidex(cpu_info, a_inp, c_inp); \
|
||||
a = cpu_info[0]; \
|
||||
b = cpu_info[1]; \
|
||||
c = cpu_info[2]; \
|
||||
d = cpu_info[3]; \
|
||||
}
|
||||
|
||||
// Visual Studio defines a builtin function, so use that if possible.
|
||||
static int GetXCR0EAX() { return _xgetbv(0); }
|
||||
|
||||
#else
|
||||
|
||||
// Otherwise use gcc-format assembler to implement the underlying instructions.
|
||||
#define GETCPUID(a, b, c, d, a_inp, c_inp) \
|
||||
asm("mov %%rbx, %%rdi\n" \
|
||||
"cpuid\n" \
|
||||
"xchg %%rdi, %%rbx\n" \
|
||||
: "=a"(a), "=D"(b), "=c"(c), "=d"(d) \
|
||||
: "a"(a_inp), "2"(c_inp))
|
||||
#endif
|
||||
|
||||
static int GetXCR0EAX() {
|
||||
int eax, edx;
|
||||
asm("XGETBV" : "=a"(eax), "=d"(edx) : "c"(0));
|
||||
return eax;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// TODO(phawkins): technically we should build this module without AVX support
|
||||
// and use configure-time tests instead of __AVX__, since there is a
|
||||
// possibility that the compiler will use AVX instructions before we reach this
|
||||
// point.
|
||||
#ifdef PLATFORM_IS_X86
|
||||
|
||||
static void ReportMissingCpuFeature(const char* name) {
|
||||
PyErr_Format(
|
||||
PyExc_RuntimeError,
|
||||
"This version of jaxlib was built using %s instructions, which your "
|
||||
"CPU and/or operating system do not support. You may be able work around "
|
||||
"this issue by building jaxlib from source.", name);
|
||||
}
|
||||
|
||||
static PyObject *CheckCpuFeatures(PyObject *self, PyObject *args) {
|
||||
uint32_t eax, ebx, ecx, edx;
|
||||
|
||||
// To get general information and extended features we send eax = 1 and
|
||||
// ecx = 0 to cpuid. The response is returned in eax, ebx, ecx and edx.
|
||||
// (See Intel 64 and IA-32 Architectures Software Developer's Manual
|
||||
// Volume 2A: Instruction Set Reference, A-M CPUID).
|
||||
GETCPUID(eax, ebx, ecx, edx, 1, 0);
|
||||
const uint64_t xcr0_xmm_mask = 0x2;
|
||||
const uint64_t xcr0_ymm_mask = 0x4;
|
||||
const uint64_t xcr0_avx_mask = xcr0_xmm_mask | xcr0_ymm_mask;
|
||||
const _Bool have_avx =
|
||||
// Does the OS support XGETBV instruction use by applications?
|
||||
((ecx >> 27) & 0x1) &&
|
||||
// Does the OS save/restore XMM and YMM state?
|
||||
((GetXCR0EAX() & xcr0_avx_mask) == xcr0_avx_mask) &&
|
||||
// Is AVX supported in hardware?
|
||||
((ecx >> 28) & 0x1);
|
||||
const _Bool have_fma = have_avx && ((ecx >> 12) & 0x1);
|
||||
|
||||
// Get standard level 7 structured extension features (issue CPUID with
|
||||
// eax = 7 and ecx= 0), which is required to check for AVX2 support as
|
||||
// well as other Haswell (and beyond) features. (See Intel 64 and IA-32
|
||||
// Architectures Software Developer's Manual Volume 2A: Instruction Set
|
||||
// Reference, A-M CPUID).
|
||||
GETCPUID(eax, ebx, ecx, edx, 7, 0);
|
||||
const _Bool have_avx2 = have_avx && ((ebx >> 5) & 0x1);
|
||||
|
||||
#ifdef __AVX__
|
||||
if (!have_avx) {
|
||||
ReportMissingCpuFeature("AVX");
|
||||
return NULL;
|
||||
}
|
||||
#endif // __AVX__
|
||||
|
||||
#ifdef __AVX2__
|
||||
if (!have_avx2) {
|
||||
ReportMissingCpuFeature("AVX2");
|
||||
return NULL;
|
||||
}
|
||||
#endif // __AVX2__
|
||||
|
||||
#ifdef __FMA__
|
||||
if (!have_fma) {
|
||||
ReportMissingCpuFeature("FMA");
|
||||
return NULL;
|
||||
}
|
||||
#endif // __FMA__
|
||||
|
||||
Py_INCREF(Py_None);
|
||||
return Py_None;
|
||||
}
|
||||
|
||||
#else // PLATFORM_IS_X86
|
||||
|
||||
static PyObject *CheckCpuFeatures(PyObject *self, PyObject *args) {
|
||||
Py_INCREF(Py_None);
|
||||
return Py_None;
|
||||
}
|
||||
|
||||
#endif // PLATFORM_IS_X86
|
||||
|
||||
static PyMethodDef cpu_feature_guard_methods[] = {
|
||||
{"check_cpu_features", CheckCpuFeatures, METH_NOARGS,
|
||||
"Throws an exception if the CPU is missing instructions used by jaxlib."},
|
||||
{NULL, NULL, 0, NULL}};
|
||||
|
||||
static struct PyModuleDef cpu_feature_guard_module = {
|
||||
PyModuleDef_HEAD_INIT, "cpu_feature_guard", /* name of module */
|
||||
NULL, -1, /* size of per-interpreter state of the module,
|
||||
or -1 if the module keeps state in global variables. */
|
||||
cpu_feature_guard_methods};
|
||||
|
||||
PyMODINIT_FUNC PyInit_cpu_feature_guard(void) {
|
||||
return PyModule_Create(&cpu_feature_guard_module);
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user