Split some submodules out of //jax under Bazel.

Add separate BUILD targets
* :version - for version.py
* _src/lib - wrapping the jaxlib shims.
* :util - for util.py
* :config - for config.py

PiperOrigin-RevId: 515307923
This commit is contained in:
Peter Hawkins 2023-03-09 05:26:58 -08:00 committed by jax authors
parent 5c914534f4
commit 0e05a7987f
8 changed files with 82 additions and 22 deletions

View File

@ -24,6 +24,7 @@ load(
"py_deps",
"py_library_providing_imports_info",
"pytype_library",
"pytype_strict_library",
)
package(default_visibility = [":internal"])
@ -101,6 +102,10 @@ py_library_providing_imports_info(
"third_party/**/*.py",
],
exclude = [
# TODO(phawkins): exclude these files after fixing up users.
# "_src/config.py",
# "_src/util.py",
"_src/lib/**",
"_src/test_util.py",
"*_test.py",
"**/*_test.py",
@ -122,20 +127,36 @@ py_library_providing_imports_info(
lib_rule = pytype_library,
pytype_srcs = glob(["_src/**/*.pyi"]),
visibility = ["//visibility:public"],
deps = select({
":enable_jaxlib_build": [":jaxlib_deps"],
"//conditions:default": [],
}) +
py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
deps = [
":config",
":util",
":version",
"//jax/_src/lib",
] + py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
)
py_library(
name = "jaxlib_deps",
pytype_library(
name = "config",
srcs = ["_src/config.py"],
deps = [
"//jaxlib",
"//jax/_src/lib",
],
)
pytype_library(
name = "util",
srcs = ["_src/util.py"],
deps = [
":config",
"//jax/_src/lib",
],
)
pytype_strict_library(
name = "version",
srcs = ["version.py"],
)
py_library_providing_imports_info(
name = "experimental",
srcs = glob(

View File

@ -135,7 +135,7 @@ class Config:
def config_with_absl(self):
# Run this before calling `app.run(main)` etc
import absl.flags as absl_FLAGS # noqa: F401
import absl.flags as absl_FLAGS # noqa: F401 # pytype: disable=import-error
from absl import app, flags as absl_flags # pytype: disable=import-error
self.use_absl = True
@ -166,7 +166,7 @@ class Config:
jax_argv = itertools.takewhile(lambda a: a != '--', sys.argv)
jax_argv = ['', *(a for a in jax_argv if a.startswith('--jax'))]
import absl.flags
import absl.flags # pytype: disable=import-error
self.config_with_absl()
absl.flags.FLAGS(jax_argv, known_only=True)
self.complete_absl_config(absl.flags)

34
jax/_src/lib/BUILD Normal file
View File

@ -0,0 +1,34 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//jaxlib:jax.bzl", "pytype_library")
package(default_visibility = ["//:__subpackages__"])
pytype_library(
name = "lib",
srcs = [
"__init__.py",
"mlir/__init__.py",
"mlir/dialects/__init__.py",
],
deps = [
"//jax:version",
] + select({
"//jax:enable_jaxlib_build": [
"//jaxlib",
],
"//conditions:default": [],
}),
)

View File

@ -42,7 +42,7 @@ except Exception as err:
# Checks the jaxlib version before importing anything else from jaxlib.
# Returns the jaxlib version string.
def check_jaxlib_version(jax_version: str, jaxlib_version: str,
minimum_jaxlib_version: str):
minimum_jaxlib_version: str) -> Tuple[int, ...]:
# Regex to match a dotted version prefix 0.1.23.456.789 of a PEP440 version.
# PEP440 allows a number of non-numeric suffixes, which we allow also.
# We currently do not allow an epoch.
@ -103,7 +103,7 @@ import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error
# Only for the internal usage of the JAX developers, we expose a version
# number that can be used to perform changes without breaking the main
# branch on the Jax github.
xla_extension_version = getattr(xla_client, '_version', 0)
xla_extension_version: int = getattr(xla_client, '_version', 0)
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error

View File

@ -20,7 +20,6 @@ import jaxlib.mlir.dialects.func as func
import jaxlib.mlir.dialects.ml_program as ml_program
import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
from jax.lib import xla_client
import jaxlib.mlir.dialects.stablehlo as stablehlo
# Alias that is set up to abstract away the transition from MHLO to StableHLO.

View File

@ -26,7 +26,7 @@ from typing import (Any, Callable, Generic, Iterable, Iterator, List,
import numpy as np
from jax._src.lib import xla_client as xc
from jax.config import config
from jax._src.config import config
logger = logging.getLogger(__name__)

View File

@ -61,6 +61,7 @@ from jax._src import pjit
from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse
import jax.numpy as jnp
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.lib import pytree
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import pxla
@ -271,7 +272,7 @@ def spvalues_to_avals(
return tree_map(spvalue_to_aval, spvalues, is_leaf=_is_spvalue)
#------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Implementation of sparsify() using tracers.
def popattr(obj: Any, name: str) -> Any:
@ -375,7 +376,7 @@ def _sparsify_with_tracer(fun):
return tree_unflatten(out_tree(), out)
return _wrapped
#------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Implementation of sparsify() using a jaxpr interpreter.
def eval_sparse(
@ -440,7 +441,10 @@ def eval_sparse(
return safe_map(read, jaxpr.outvars)
def sparsify_raw(f):
def wrapped(spenv: SparsifyEnv, *spvalues: SparsifyValue, **params: Any) -> Tuple[Sequence[SparsifyValue], bool]:
def wrapped(
spenv: SparsifyEnv, *spvalues: SparsifyValue, **params: Any
) -> Tuple[Sequence[SparsifyValue], pytree.PyTreeDef]:
spvalues_flat, in_tree = tree_flatten(spvalues, is_leaf=_is_spvalue)
in_avals_flat = spvalues_to_avals(spenv, spvalues_flat)
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree)
@ -450,6 +454,7 @@ def sparsify_raw(f):
raise Exception("Internal: eval_sparse does not return expected number of arguments. "
"Got {result} for avals {out_avals_flat}")
return result, out_tree()
return wrapped
def _sparsify_with_interpreter(f):
@ -491,7 +496,7 @@ def sparsify(f, use_tracer=False):
return _sparsify_with_interpreter(f)
#------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Sparse rules for various primitives
def _ensure_unique_indices(spenv, spvalue):
@ -680,7 +685,6 @@ def _reduce_sum_sparse(spenv, *spvalues, axes):
sparse_rules_bcoo[lax.reduce_sum_p] = _reduce_sum_sparse
def _gather_sparse_rule(spenv, *args, dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):
operand, start_indices = spvalues_to_arrays(spenv, args)
@ -697,7 +701,7 @@ def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
# shared data & indices when generating the sparsified jaxpr. The
# current approach produces valid sparsified while loops, but they
# don't work in corner cases (see associated TODO in sparsify_test.py)
out_tree = None
out_tree: Optional[pytree.PyTreeDef] = None
@lu.wrap_init
def wrapped(*args_flat):
@ -718,6 +722,7 @@ def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
assert out_tree is not None
return sp_jaxpr, out_tree
def _while_sparse(spenv, *spvalues, cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts):
@ -846,7 +851,7 @@ def _todense_sparse_rule(spenv, spvalue, *, tree):
sparse_rules_bcoo[sparse.todense_p] = _todense_sparse_rule
#------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# BCOO methods derived from sparsify
# defined here to avoid circular imports
@ -908,7 +913,7 @@ _bcoo_methods = {
for method, impl in _bcoo_methods.items():
setattr(BCOO, method, impl)
#------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# BCSR methods derived from sparsify
# defined here to avoid circular imports

View File

@ -26,6 +26,7 @@ load("@org_tensorflow//tensorflow/core/platform:build_config_root.bzl", _tf_cuda
cuda_library = _cuda_library
rocm_library = _rocm_library
pytype_library = native.py_library
pytype_strict_library = native.py_library
pytype_test = native.py_test
pyx_library = _pyx_library
pybind_extension = _pybind_extension