rocm_jax/jaxlib/pywrap.bzl
Peter Hawkins db11efab3b Migrate jaxlib to use a single common .so file for all C++ dependencies.
The idea is to move all of the jaxlib contents into a single .so file,
and have all of the other Python extensions be tiny stubs that re-export
part of the larger .so file.

This has two main benefits:
* it reduces the size of the jaxlib wheel, by about 70-80MB when
  installed. The saving comes from avoiding duplication
  between the MLIR C API code and the copy of MLIR in XLA.
* it gives us flexibility to split and merge Python extensions as we see
  fit.

Issue https://github.com/jax-ml/jax/issues/11225

PiperOrigin-RevId: 744855997
2025-04-07 14:44:08 -07:00

90 lines
2.8 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrappers around pywrap rules for JAX."""
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
load(
"@xla//third_party/py/rules_pywrap:pywrap.impl.bzl",
"pybind_extension",
_pywrap_binaries = "pywrap_binaries",
_pywrap_library = "pywrap_library",
)
# Re-bind the pywrap rules loaded from XLA under public (non-underscore) names.
# Starlark does not allow `load()` to import `_`-prefixed symbols, so these
# assignments are what make the rules loadable from this .bzl file.
pywrap_binaries = _pywrap_binaries
pywrap_library = _pywrap_library
def nanobind_pywrap_extension(
        name,
        srcs = [],
        deps = [],
        pytype_srcs = [],
        pytype_deps = [],
        copts = [],
        linkopts = [],
        visibility = None):
    """Python extension rule using nanobind and the pywrap rules.

    The extension's C++ code is compiled into a private cc_library that becomes
    part of the common pywrap library for the "jaxlib" package. The extension
    module itself is built from a small C stub, generated from
    //jaxlib:pyinit_stub.c, which forwards into that common library.

    Targets created:
      <name>_pywrap_library: private cc_library holding srcs/deps.
      <name>_pywrap_stub: private expand_template producing <name>_pywrap_stub.c.
      <name>: the Python extension (pybind_extension from the pywrap rules).
      <name>_type_stubs: py_library carrying the type stubs as data.

    Args:
      name: Name of the extension target; also used as the Python module name.
      srcs: C++ sources of the extension.
      deps: C++ dependencies of the extension.
      pytype_srcs: Type stub files, attached as data to <name>_type_stubs.
      pytype_deps: Dependencies of <name>_type_stubs.
      copts: Extra compiler options for the extension's C++ sources (applied
        to the internal cc_library).
      linkopts: Extra linker options for the extension target.
      visibility: Visibility of the extension target. The internal library and
        stub targets are always private.
    """
    module_name = name
    lib_name = name + "_pywrap_library"
    src_cc_name = name + "_pywrap_stub.c"

    # We put the entire contents of the extension in a single cc_library, which will become part of
    # the common pywrap library. All the contents of all extensions will end up in the common
    # library. Inside it, PyInit_<module> is renamed to Wrapped_PyInit_<module> by a preprocessor
    # define, presumably so that only the stub below exports the real PyInit_<module> entry point.
    native.cc_library(
        name = lib_name,
        srcs = srcs,
        copts = copts,
        deps = deps,
        local_defines = [
            "PyInit_{0}=Wrapped_PyInit_{0}".format(module_name),
        ],
        visibility = ["//visibility:private"],
    )

    # We build a small stub source as the extension that forwards to the renamed
    # Wrapped_PyInit_... symbol from the common pywrap library (see //jaxlib:pyinit_stub.c).
    # This is not testonly: the generated file is compiled into the shipped extension below,
    # which is a non-test target.
    expand_template(
        name = name + "_pywrap_stub",
        out = src_cc_name,
        substitutions = {
            "@MODULE_NAME@": module_name,
        },
        template = "//jaxlib:pyinit_stub.c",
        visibility = ["//visibility:private"],
    )

    # Despite its name "pybind_extension" has nothing to do with pybind. It is the Python extension
    # rule from the pywrap rules.
    pybind_extension(
        name = name,
        srcs = [src_cc_name],
        deps = [":" + lib_name],
        linkopts = linkopts,
        visibility = visibility,
        default_deps = [],
        common_lib_packages = [
            "jaxlib",
        ],
    )

    # Create a py_library with the type stubs as data, on which wheel builds can depend.
    native.py_library(
        name = name + "_type_stubs",
        data = pytype_srcs,
        deps = pytype_deps,
    )