Drop flatbuffers as a Python dependency of JAX.

Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457460347
This commit is contained in:
Peter Hawkins 2022-06-27 06:13:36 -07:00 committed by jax authors
parent 93f5113c93
commit efefeac450
7 changed files with 65 additions and 71 deletions

View File

@ -180,7 +180,6 @@ def prepare_wheel(sources_path):
copy_to_jaxlib(f"__main__/jaxlib/_lapack.{pyext}")
copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py")
copy_to_jaxlib(f"__main__/jaxlib/_pocketfft.{pyext}")
copy_to_jaxlib("__main__/jaxlib/pocketfft_flatbuffers_py_generated.py")
copy_to_jaxlib("__main__/jaxlib/pocketfft.py")
copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")

View File

@ -1,6 +1,5 @@
cloudpickle
colorama>=0.4.4
flatbuffers==2.0
# TODO(jakevdp): fix use of deprecated NEAREST resampling for more recent pillow.
pillow>=8.3.1,<9.1.0
pytest-benchmark

View File

@ -17,7 +17,6 @@
load(
"//jaxlib:jax.bzl",
"flatbuffer_cc_library",
"flatbuffer_py_library",
"pybind_extension",
)
@ -86,7 +85,6 @@ py_library(
":_lapack",
":_pocketfft",
":cpu_feature_guard",
":pocketfft_flatbuffers_py",
],
)
@ -148,11 +146,6 @@ flatbuffer_cc_library(
srcs = ["pocketfft.fbs"],
)
flatbuffer_py_library(
name = "pocketfft_flatbuffers_py",
srcs = ["pocketfft.fbs"],
)
cc_library(
name = "pocketfft_kernels",
srcs = ["pocketfft_kernels.cc"],
@ -178,7 +171,9 @@ pybind_extension(
module_name = "_pocketfft",
deps = [
":kernel_pybind11_helpers",
":pocketfft_flatbuffers_cc",
":pocketfft_kernels",
"@flatbuffers//:runtime_cc",
"@pybind11",
],
)

View File

@ -18,7 +18,7 @@ load("@org_tensorflow//tensorflow/core/platform/default:build_config.bzl", _pyx_
load("@org_tensorflow//tensorflow:tensorflow.bzl", _pybind_extension = "pybind_extension")
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library", _flatbuffer_py_library = "flatbuffer_py_library")
load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
# Explicitly re-exports names to avoid "unused variable" warnings from .bzl
# lint tools.
@ -30,7 +30,6 @@ pybind_extension = _pybind_extension
if_cuda_is_configured = _if_cuda_is_configured
if_rocm_is_configured = _if_rocm_is_configured
flatbuffer_cc_library = _flatbuffer_cc_library
flatbuffer_py_library = _flatbuffer_py_library
def py_extension(name, srcs, copts, deps):
pybind_extension(name, srcs = srcs, copts = copts, deps = deps, module_name = name)

View File

@ -14,21 +14,55 @@ limitations under the License.
==============================================================================*/
#include <complex>
#include <vector>
#include "jaxlib/kernel_pybind11_helpers.h"
#include "jaxlib/pocketfft_generated.h"
#include "jaxlib/pocketfft_kernels.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
namespace py = pybind11;
namespace jax {
namespace {
pybind11::dict Registrations() {
py::bytes BuildPocketFftDescriptor(const std::vector<uint64_t>& shape,
bool is_double, int fft_type,
const std::vector<uint64_t>& fft_lengths,
const std::vector<uint64_t>& strides_in,
const std::vector<uint64_t>& strides_out,
const std::vector<uint32_t>& axes,
bool forward, double scale) {
PocketFftDescriptorT descriptor;
descriptor.shape = shape;
descriptor.fft_type = static_cast<PocketFftType>(fft_type);
descriptor.dtype =
is_double ? PocketFftDtype_COMPLEX128 : PocketFftDtype_COMPLEX64;
descriptor.strides_in = strides_in;
descriptor.strides_out = strides_out;
descriptor.axes = axes;
descriptor.forward = forward;
descriptor.scale = scale;
flatbuffers::FlatBufferBuilder fbb;
fbb.Finish(PocketFftDescriptor::Pack(fbb, &descriptor));
return py::bytes(reinterpret_cast<char*>(fbb.GetBufferPointer()),
fbb.GetSize());
}
py::dict Registrations() {
pybind11::dict dict;
dict["pocketfft"] = EncapsulateFunction(PocketFft);
return dict;
}
PYBIND11_MODULE(_pocketfft, m) { m.def("registrations", &Registrations); }
PYBIND11_MODULE(_pocketfft, m) {
m.def("registrations", &Registrations);
m.def("pocketfft_descriptor", &BuildPocketFftDescriptor, py::arg("shape"),
py::arg("is_double"), py::arg("fft_type"), py::arg("fft_lengths"),
py::arg("strides_in"), py::arg("strides_out"), py::arg("axes"),
py::arg("forward"), py::arg("scale"));
}
} // namespace
} // namespace jax

View File

@ -13,8 +13,6 @@
# limitations under the License.
import jax
# flatbuffers needs importlib.util but fails to import it itself.
import importlib.util # noqa: F401
from typing import List
import jaxlib.mlir.ir as ir
@ -23,10 +21,8 @@ import jaxlib.mlir.dialects.mhlo as mhlo
from .mhlo_helpers import custom_call
from . import _pocketfft
from . import pocketfft_flatbuffers_py_generated as pd
import numpy as np
import flatbuffers
from jaxlib import xla_client
for _name, _value in _pocketfft.registrations().items():
@ -34,8 +30,10 @@ for _name, _value in _pocketfft.registrations().items():
FftType = xla_client.FftType
flatbuffers_version_2 = hasattr(flatbuffers, "__version__")
_C2C = 0
_C2R = 1
_R2C = 2
def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
fft_lengths: List[int]) -> bytes:
@ -43,43 +41,34 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
assert len(fft_lengths) >= 1
assert len(fft_lengths) <= n, (fft_lengths, n)
builder = flatbuffers.Builder(128)
forward = fft_type in (FftType.FFT, FftType.RFFT)
is_double = np.finfo(dtype).dtype == np.float64
if fft_type == FftType.RFFT:
pocketfft_type = pd.PocketFftType.R2C
pocketfft_type = _R2C
assert dtype in (np.float32, np.float64), dtype
out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
pocketfft_dtype = (
pd.PocketFftDtype.COMPLEX64
if dtype == np.float32 else pd.PocketFftDtype.COMPLEX128)
assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths)
out_shape = list(shape)
out_shape[-1] = out_shape[-1] // 2 + 1
elif fft_type == FftType.IRFFT:
pocketfft_type = pd.PocketFftType.C2R
pocketfft_type = _C2R
assert np.issubdtype(dtype, np.complexfloating), dtype
out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
pocketfft_dtype = (
pd.PocketFftDtype.COMPLEX64
if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)
assert shape[-len(fft_lengths):-1] == fft_lengths[:-1]
out_shape = list(shape)
out_shape[-1] = fft_lengths[-1]
assert (out_shape[-1] // 2 + 1) == shape[-1]
else:
pocketfft_type = pd.PocketFftType.C2C
pocketfft_type = _C2C
assert np.issubdtype(dtype, np.complexfloating), dtype
out_dtype = dtype
pocketfft_dtype = (
pd.PocketFftDtype.COMPLEX64
if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)
assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths)
out_shape = shape
@ -90,54 +79,33 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
# Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
# C++ kernel to describe the FFT to perform.
pd.PocketFftDescriptorStartShapeVector(builder, n)
for d in reversed(shape if fft_type != FftType.IRFFT else out_shape):
builder.PrependUint64(d)
if flatbuffers_version_2:
pocketfft_shape = builder.EndVector()
else:
pocketfft_shape = builder.EndVector(n)
pd.PocketFftDescriptorStartStridesInVector(builder, n)
strides_in = []
stride = dtype.itemsize
for d in reversed(shape):
builder.PrependUint64(stride)
strides_in.append(stride)
stride *= d
if flatbuffers_version_2:
strides_in = builder.EndVector()
else:
strides_in = builder.EndVector(n)
pd.PocketFftDescriptorStartStridesOutVector(builder, n)
strides_out = []
stride = out_dtype.itemsize
for d in reversed(out_shape):
builder.PrependUint64(stride)
strides_out.append(stride)
stride *= d
if flatbuffers_version_2:
strides_out = builder.EndVector()
else:
strides_out = builder.EndVector(n)
pd.PocketFftDescriptorStartAxesVector(builder, len(fft_lengths))
for d in range(len(fft_lengths)):
builder.PrependUint32(n - d - 1)
if flatbuffers_version_2:
axes = builder.EndVector()
else:
axes = builder.EndVector(len(fft_lengths))
axes = [n - len(fft_lengths) + d for d in range(len(fft_lengths))]
scale = 1. if forward else (1. / np.prod(fft_lengths))
pd.PocketFftDescriptorStart(builder)
pd.PocketFftDescriptorAddDtype(builder, pocketfft_dtype)
pd.PocketFftDescriptorAddFftType(builder, pocketfft_type)
pd.PocketFftDescriptorAddShape(builder, pocketfft_shape)
pd.PocketFftDescriptorAddStridesIn(builder, strides_in)
pd.PocketFftDescriptorAddStridesOut(builder, strides_out)
pd.PocketFftDescriptorAddAxes(builder, axes)
pd.PocketFftDescriptorAddForward(builder, forward)
pd.PocketFftDescriptorAddScale(builder, scale)
descriptor = pd.PocketFftDescriptorEnd(builder)
builder.Finish(descriptor)
return builder.Output(), out_dtype, out_shape
descriptor = _pocketfft.pocketfft_descriptor(
shape=shape if fft_type != FftType.IRFFT else out_shape,
is_double=is_double,
fft_type=pocketfft_type,
fft_lengths=fft_lengths,
strides_in=list(reversed(strides_in)),
strides_out=list(reversed(strides_out)),
axes=axes,
forward=forward,
scale=scale)
return descriptor, out_dtype, out_shape
def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):

View File

@ -38,7 +38,7 @@ setup(
author_email='jax-dev@google.com',
packages=['jaxlib', 'jaxlib.xla_extension'],
python_requires='>=3.7',
install_requires=['scipy>=1.5', 'numpy>=1.19', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
install_requires=['scipy>=1.5', 'numpy>=1.19', 'absl-py'],
url='https://github.com/google/jax',
license='Apache-2.0',
classifiers=[