mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Drop flatbuffers as a Python dependency of JAX.
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead of in Python. Surprisingly, this code is actually much more readable in C++, because the flatbuffers Python API has no counterpart to the readable (but less efficient) object-based API that the C++ code uses. Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.
PiperOrigin-RevId: 457460347
This commit is contained in:
parent
93f5113c93
commit
efefeac450
@ -180,7 +180,6 @@ def prepare_wheel(sources_path):
|
||||
copy_to_jaxlib(f"__main__/jaxlib/_lapack.{pyext}")
|
||||
copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py")
|
||||
copy_to_jaxlib(f"__main__/jaxlib/_pocketfft.{pyext}")
|
||||
copy_to_jaxlib("__main__/jaxlib/pocketfft_flatbuffers_py_generated.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/pocketfft.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
|
||||
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")
|
||||
|
@ -1,6 +1,5 @@
|
||||
cloudpickle
|
||||
colorama>=0.4.4
|
||||
flatbuffers==2.0
|
||||
# TODO(jakevdp): fix use of deprecated NEAREST resampling for more recent pillow.
|
||||
pillow>=8.3.1,<9.1.0
|
||||
pytest-benchmark
|
||||
|
@ -17,7 +17,6 @@
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"flatbuffer_cc_library",
|
||||
"flatbuffer_py_library",
|
||||
"pybind_extension",
|
||||
)
|
||||
|
||||
@ -86,7 +85,6 @@ py_library(
|
||||
":_lapack",
|
||||
":_pocketfft",
|
||||
":cpu_feature_guard",
|
||||
":pocketfft_flatbuffers_py",
|
||||
],
|
||||
)
|
||||
|
||||
@ -148,11 +146,6 @@ flatbuffer_cc_library(
|
||||
srcs = ["pocketfft.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_py_library(
|
||||
name = "pocketfft_flatbuffers_py",
|
||||
srcs = ["pocketfft.fbs"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "pocketfft_kernels",
|
||||
srcs = ["pocketfft_kernels.cc"],
|
||||
@ -178,7 +171,9 @@ pybind_extension(
|
||||
module_name = "_pocketfft",
|
||||
deps = [
|
||||
":kernel_pybind11_helpers",
|
||||
":pocketfft_flatbuffers_cc",
|
||||
":pocketfft_kernels",
|
||||
"@flatbuffers//:runtime_cc",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
@ -18,7 +18,7 @@ load("@org_tensorflow//tensorflow/core/platform/default:build_config.bzl", _pyx_
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", _pybind_extension = "pybind_extension")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
|
||||
load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library", _flatbuffer_py_library = "flatbuffer_py_library")
|
||||
load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
|
||||
|
||||
# Explicitly re-exports names to avoid "unused variable" warnings from .bzl
|
||||
# lint tools.
|
||||
@ -30,7 +30,6 @@ pybind_extension = _pybind_extension
|
||||
if_cuda_is_configured = _if_cuda_is_configured
|
||||
if_rocm_is_configured = _if_rocm_is_configured
|
||||
flatbuffer_cc_library = _flatbuffer_cc_library
|
||||
flatbuffer_py_library = _flatbuffer_py_library
|
||||
|
||||
def py_extension(name, srcs, copts, deps):
|
||||
pybind_extension(name, srcs = srcs, copts = copts, deps = deps, module_name = name)
|
||||
|
@ -14,21 +14,55 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <complex>
|
||||
#include <vector>
|
||||
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "jaxlib/pocketfft_generated.h"
|
||||
#include "jaxlib/pocketfft_kernels.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace jax {
|
||||
namespace {
|
||||
|
||||
pybind11::dict Registrations() {
|
||||
py::bytes BuildPocketFftDescriptor(const std::vector<uint64_t>& shape,
|
||||
bool is_double, int fft_type,
|
||||
const std::vector<uint64_t>& fft_lengths,
|
||||
const std::vector<uint64_t>& strides_in,
|
||||
const std::vector<uint64_t>& strides_out,
|
||||
const std::vector<uint32_t>& axes,
|
||||
bool forward, double scale) {
|
||||
PocketFftDescriptorT descriptor;
|
||||
descriptor.shape = shape;
|
||||
descriptor.fft_type = static_cast<PocketFftType>(fft_type);
|
||||
descriptor.dtype =
|
||||
is_double ? PocketFftDtype_COMPLEX128 : PocketFftDtype_COMPLEX64;
|
||||
descriptor.strides_in = strides_in;
|
||||
descriptor.strides_out = strides_out;
|
||||
descriptor.axes = axes;
|
||||
descriptor.forward = forward;
|
||||
descriptor.scale = scale;
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
fbb.Finish(PocketFftDescriptor::Pack(fbb, &descriptor));
|
||||
return py::bytes(reinterpret_cast<char*>(fbb.GetBufferPointer()),
|
||||
fbb.GetSize());
|
||||
}
|
||||
|
||||
py::dict Registrations() {
|
||||
pybind11::dict dict;
|
||||
dict["pocketfft"] = EncapsulateFunction(PocketFft);
|
||||
return dict;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_pocketfft, m) { m.def("registrations", &Registrations); }
|
||||
PYBIND11_MODULE(_pocketfft, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("pocketfft_descriptor", &BuildPocketFftDescriptor, py::arg("shape"),
|
||||
py::arg("is_double"), py::arg("fft_type"), py::arg("fft_lengths"),
|
||||
py::arg("strides_in"), py::arg("strides_out"), py::arg("axes"),
|
||||
py::arg("forward"), py::arg("scale"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace jax
|
||||
|
@ -13,8 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import jax
|
||||
# flatbuffers needs importlib.util but fails to import it itself.
|
||||
import importlib.util # noqa: F401
|
||||
from typing import List
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
@ -23,10 +21,8 @@ import jaxlib.mlir.dialects.mhlo as mhlo
|
||||
|
||||
from .mhlo_helpers import custom_call
|
||||
from . import _pocketfft
|
||||
from . import pocketfft_flatbuffers_py_generated as pd
|
||||
import numpy as np
|
||||
|
||||
import flatbuffers
|
||||
from jaxlib import xla_client
|
||||
|
||||
for _name, _value in _pocketfft.registrations().items():
|
||||
@ -34,8 +30,10 @@ for _name, _value in _pocketfft.registrations().items():
|
||||
|
||||
FftType = xla_client.FftType
|
||||
|
||||
flatbuffers_version_2 = hasattr(flatbuffers, "__version__")
|
||||
|
||||
_C2C = 0
|
||||
_C2R = 1
|
||||
_R2C = 2
|
||||
|
||||
def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
fft_lengths: List[int]) -> bytes:
|
||||
@ -43,43 +41,34 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
assert len(fft_lengths) >= 1
|
||||
assert len(fft_lengths) <= n, (fft_lengths, n)
|
||||
|
||||
builder = flatbuffers.Builder(128)
|
||||
|
||||
forward = fft_type in (FftType.FFT, FftType.RFFT)
|
||||
is_double = np.finfo(dtype).dtype == np.float64
|
||||
if fft_type == FftType.RFFT:
|
||||
pocketfft_type = pd.PocketFftType.R2C
|
||||
pocketfft_type = _R2C
|
||||
|
||||
assert dtype in (np.float32, np.float64), dtype
|
||||
out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
|
||||
pocketfft_dtype = (
|
||||
pd.PocketFftDtype.COMPLEX64
|
||||
if dtype == np.float32 else pd.PocketFftDtype.COMPLEX128)
|
||||
|
||||
assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths)
|
||||
out_shape = list(shape)
|
||||
out_shape[-1] = out_shape[-1] // 2 + 1
|
||||
|
||||
elif fft_type == FftType.IRFFT:
|
||||
pocketfft_type = pd.PocketFftType.C2R
|
||||
pocketfft_type = _C2R
|
||||
assert np.issubdtype(dtype, np.complexfloating), dtype
|
||||
|
||||
out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
|
||||
pocketfft_dtype = (
|
||||
pd.PocketFftDtype.COMPLEX64
|
||||
if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)
|
||||
|
||||
assert shape[-len(fft_lengths):-1] == fft_lengths[:-1]
|
||||
out_shape = list(shape)
|
||||
out_shape[-1] = fft_lengths[-1]
|
||||
assert (out_shape[-1] // 2 + 1) == shape[-1]
|
||||
else:
|
||||
pocketfft_type = pd.PocketFftType.C2C
|
||||
pocketfft_type = _C2C
|
||||
|
||||
assert np.issubdtype(dtype, np.complexfloating), dtype
|
||||
out_dtype = dtype
|
||||
pocketfft_dtype = (
|
||||
pd.PocketFftDtype.COMPLEX64
|
||||
if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)
|
||||
|
||||
assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths)
|
||||
out_shape = shape
|
||||
@ -90,54 +79,33 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
|
||||
# Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
|
||||
# C++ kernel to describe the FFT to perform.
|
||||
pd.PocketFftDescriptorStartShapeVector(builder, n)
|
||||
for d in reversed(shape if fft_type != FftType.IRFFT else out_shape):
|
||||
builder.PrependUint64(d)
|
||||
if flatbuffers_version_2:
|
||||
pocketfft_shape = builder.EndVector()
|
||||
else:
|
||||
pocketfft_shape = builder.EndVector(n)
|
||||
|
||||
pd.PocketFftDescriptorStartStridesInVector(builder, n)
|
||||
strides_in = []
|
||||
stride = dtype.itemsize
|
||||
for d in reversed(shape):
|
||||
builder.PrependUint64(stride)
|
||||
strides_in.append(stride)
|
||||
stride *= d
|
||||
if flatbuffers_version_2:
|
||||
strides_in = builder.EndVector()
|
||||
else:
|
||||
strides_in = builder.EndVector(n)
|
||||
pd.PocketFftDescriptorStartStridesOutVector(builder, n)
|
||||
|
||||
strides_out = []
|
||||
stride = out_dtype.itemsize
|
||||
for d in reversed(out_shape):
|
||||
builder.PrependUint64(stride)
|
||||
strides_out.append(stride)
|
||||
stride *= d
|
||||
if flatbuffers_version_2:
|
||||
strides_out = builder.EndVector()
|
||||
else:
|
||||
strides_out = builder.EndVector(n)
|
||||
|
||||
pd.PocketFftDescriptorStartAxesVector(builder, len(fft_lengths))
|
||||
for d in range(len(fft_lengths)):
|
||||
builder.PrependUint32(n - d - 1)
|
||||
if flatbuffers_version_2:
|
||||
axes = builder.EndVector()
|
||||
else:
|
||||
axes = builder.EndVector(len(fft_lengths))
|
||||
axes = [n - len(fft_lengths) + d for d in range(len(fft_lengths))]
|
||||
|
||||
scale = 1. if forward else (1. / np.prod(fft_lengths))
|
||||
pd.PocketFftDescriptorStart(builder)
|
||||
pd.PocketFftDescriptorAddDtype(builder, pocketfft_dtype)
|
||||
pd.PocketFftDescriptorAddFftType(builder, pocketfft_type)
|
||||
pd.PocketFftDescriptorAddShape(builder, pocketfft_shape)
|
||||
pd.PocketFftDescriptorAddStridesIn(builder, strides_in)
|
||||
pd.PocketFftDescriptorAddStridesOut(builder, strides_out)
|
||||
pd.PocketFftDescriptorAddAxes(builder, axes)
|
||||
pd.PocketFftDescriptorAddForward(builder, forward)
|
||||
pd.PocketFftDescriptorAddScale(builder, scale)
|
||||
descriptor = pd.PocketFftDescriptorEnd(builder)
|
||||
builder.Finish(descriptor)
|
||||
return builder.Output(), out_dtype, out_shape
|
||||
descriptor = _pocketfft.pocketfft_descriptor(
|
||||
shape=shape if fft_type != FftType.IRFFT else out_shape,
|
||||
is_double=is_double,
|
||||
fft_type=pocketfft_type,
|
||||
fft_lengths=fft_lengths,
|
||||
strides_in=list(reversed(strides_in)),
|
||||
strides_out=list(reversed(strides_out)),
|
||||
axes=axes,
|
||||
forward=forward,
|
||||
scale=scale)
|
||||
|
||||
return descriptor, out_dtype, out_shape
|
||||
|
||||
|
||||
def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
|
||||
|
@ -38,7 +38,7 @@ setup(
|
||||
author_email='jax-dev@google.com',
|
||||
packages=['jaxlib', 'jaxlib.xla_extension'],
|
||||
python_requires='>=3.7',
|
||||
install_requires=['scipy>=1.5', 'numpy>=1.19', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
|
||||
install_requires=['scipy>=1.5', 'numpy>=1.19', 'absl-py'],
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
classifiers=[
|
||||
|
Loading…
x
Reference in New Issue
Block a user