mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Split some submodules out of //jax under Bazel.
Add separate BUILD targets * :version - for version.py * _src/lib - wrapping the jaxlib shims. * :util - for util.py * :config - for config.py PiperOrigin-RevId: 515307923
This commit is contained in:
parent
5c914534f4
commit
0e05a7987f
37
jax/BUILD
37
jax/BUILD
@ -24,6 +24,7 @@ load(
|
||||
"py_deps",
|
||||
"py_library_providing_imports_info",
|
||||
"pytype_library",
|
||||
"pytype_strict_library",
|
||||
)
|
||||
|
||||
package(default_visibility = [":internal"])
|
||||
@ -101,6 +102,10 @@ py_library_providing_imports_info(
|
||||
"third_party/**/*.py",
|
||||
],
|
||||
exclude = [
|
||||
# TODO(phawkins): exclude these files after fixing up users.
|
||||
# "_src/config.py",
|
||||
# "_src/util.py",
|
||||
"_src/lib/**",
|
||||
"_src/test_util.py",
|
||||
"*_test.py",
|
||||
"**/*_test.py",
|
||||
@ -122,20 +127,36 @@ py_library_providing_imports_info(
|
||||
lib_rule = pytype_library,
|
||||
pytype_srcs = glob(["_src/**/*.pyi"]),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
":enable_jaxlib_build": [":jaxlib_deps"],
|
||||
"//conditions:default": [],
|
||||
}) +
|
||||
py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
|
||||
deps = [
|
||||
":config",
|
||||
":util",
|
||||
":version",
|
||||
"//jax/_src/lib",
|
||||
] + py_deps("numpy") + py_deps("scipy") + jax_extra_deps,
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "jaxlib_deps",
|
||||
pytype_library(
|
||||
name = "config",
|
||||
srcs = ["_src/config.py"],
|
||||
deps = [
|
||||
"//jaxlib",
|
||||
"//jax/_src/lib",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "util",
|
||||
srcs = ["_src/util.py"],
|
||||
deps = [
|
||||
":config",
|
||||
"//jax/_src/lib",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "version",
|
||||
srcs = ["version.py"],
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
name = "experimental",
|
||||
srcs = glob(
|
||||
|
@ -135,7 +135,7 @@ class Config:
|
||||
|
||||
def config_with_absl(self):
|
||||
# Run this before calling `app.run(main)` etc
|
||||
import absl.flags as absl_FLAGS # noqa: F401
|
||||
import absl.flags as absl_FLAGS # noqa: F401 # pytype: disable=import-error
|
||||
from absl import app, flags as absl_flags # pytype: disable=import-error
|
||||
|
||||
self.use_absl = True
|
||||
@ -166,7 +166,7 @@ class Config:
|
||||
jax_argv = itertools.takewhile(lambda a: a != '--', sys.argv)
|
||||
jax_argv = ['', *(a for a in jax_argv if a.startswith('--jax'))]
|
||||
|
||||
import absl.flags
|
||||
import absl.flags # pytype: disable=import-error
|
||||
self.config_with_absl()
|
||||
absl.flags.FLAGS(jax_argv, known_only=True)
|
||||
self.complete_absl_config(absl.flags)
|
||||
|
34
jax/_src/lib/BUILD
Normal file
34
jax/_src/lib/BUILD
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//jaxlib:jax.bzl", "pytype_library")
|
||||
|
||||
package(default_visibility = ["//:__subpackages__"])
|
||||
|
||||
pytype_library(
|
||||
name = "lib",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"mlir/__init__.py",
|
||||
"mlir/dialects/__init__.py",
|
||||
],
|
||||
deps = [
|
||||
"//jax:version",
|
||||
] + select({
|
||||
"//jax:enable_jaxlib_build": [
|
||||
"//jaxlib",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
)
|
@ -42,7 +42,7 @@ except Exception as err:
|
||||
# Checks the jaxlib version before importing anything else from jaxlib.
|
||||
# Returns the jaxlib version string.
|
||||
def check_jaxlib_version(jax_version: str, jaxlib_version: str,
|
||||
minimum_jaxlib_version: str):
|
||||
minimum_jaxlib_version: str) -> Tuple[int, ...]:
|
||||
# Regex to match a dotted version prefix 0.1.23.456.789 of a PEP440 version.
|
||||
# PEP440 allows a number of non-numeric suffixes, which we allow also.
|
||||
# We currently do not allow an epoch.
|
||||
@ -103,7 +103,7 @@ import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error
|
||||
# Only for the internal usage of the JAX developers, we expose a version
|
||||
# number that can be used to perform changes without breaking the main
|
||||
# branch on the Jax github.
|
||||
xla_extension_version = getattr(xla_client, '_version', 0)
|
||||
xla_extension_version: int = getattr(xla_client, '_version', 0)
|
||||
|
||||
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error
|
||||
|
||||
|
@ -20,7 +20,6 @@ import jaxlib.mlir.dialects.func as func
|
||||
import jaxlib.mlir.dialects.ml_program as ml_program
|
||||
import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
|
||||
|
||||
from jax.lib import xla_client
|
||||
import jaxlib.mlir.dialects.stablehlo as stablehlo
|
||||
|
||||
# Alias that is set up to abstract away the transition from MHLO to StableHLO.
|
||||
|
@ -26,7 +26,7 @@ from typing import (Any, Callable, Generic, Iterable, Iterator, List,
|
||||
import numpy as np
|
||||
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax.config import config
|
||||
from jax._src.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -61,6 +61,7 @@ from jax._src import pjit
|
||||
from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse
|
||||
import jax.numpy as jnp
|
||||
from jax._src.api_util import flatten_fun_nokwargs
|
||||
from jax._src.lib import pytree
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
@ -271,7 +272,7 @@ def spvalues_to_avals(
|
||||
return tree_map(spvalue_to_aval, spvalues, is_leaf=_is_spvalue)
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
# ------------------------------------------------------------------------------
|
||||
# Implementation of sparsify() using tracers.
|
||||
|
||||
def popattr(obj: Any, name: str) -> Any:
|
||||
@ -375,7 +376,7 @@ def _sparsify_with_tracer(fun):
|
||||
return tree_unflatten(out_tree(), out)
|
||||
return _wrapped
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
# ------------------------------------------------------------------------------
|
||||
# Implementation of sparsify() using a jaxpr interpreter.
|
||||
|
||||
def eval_sparse(
|
||||
@ -440,7 +441,10 @@ def eval_sparse(
|
||||
return safe_map(read, jaxpr.outvars)
|
||||
|
||||
def sparsify_raw(f):
|
||||
def wrapped(spenv: SparsifyEnv, *spvalues: SparsifyValue, **params: Any) -> Tuple[Sequence[SparsifyValue], bool]:
|
||||
|
||||
def wrapped(
|
||||
spenv: SparsifyEnv, *spvalues: SparsifyValue, **params: Any
|
||||
) -> Tuple[Sequence[SparsifyValue], pytree.PyTreeDef]:
|
||||
spvalues_flat, in_tree = tree_flatten(spvalues, is_leaf=_is_spvalue)
|
||||
in_avals_flat = spvalues_to_avals(spenv, spvalues_flat)
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree)
|
||||
@ -450,6 +454,7 @@ def sparsify_raw(f):
|
||||
raise Exception("Internal: eval_sparse does not return expected number of arguments. "
|
||||
"Got {result} for avals {out_avals_flat}")
|
||||
return result, out_tree()
|
||||
|
||||
return wrapped
|
||||
|
||||
def _sparsify_with_interpreter(f):
|
||||
@ -491,7 +496,7 @@ def sparsify(f, use_tracer=False):
|
||||
return _sparsify_with_interpreter(f)
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
# ------------------------------------------------------------------------------
|
||||
# Sparse rules for various primitives
|
||||
|
||||
def _ensure_unique_indices(spenv, spvalue):
|
||||
@ -680,7 +685,6 @@ def _reduce_sum_sparse(spenv, *spvalues, axes):
|
||||
sparse_rules_bcoo[lax.reduce_sum_p] = _reduce_sum_sparse
|
||||
|
||||
|
||||
|
||||
def _gather_sparse_rule(spenv, *args, dimension_numbers, slice_sizes, unique_indices,
|
||||
indices_are_sorted, mode, fill_value):
|
||||
operand, start_indices = spvalues_to_arrays(spenv, args)
|
||||
@ -697,7 +701,7 @@ def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
|
||||
# shared data & indices when generating the sparsified jaxpr. The
|
||||
# current approach produces valid sparsified while loops, but they
|
||||
# don't work in corner cases (see associated TODO in sparsify_test.py)
|
||||
out_tree = None
|
||||
out_tree: Optional[pytree.PyTreeDef] = None
|
||||
|
||||
@lu.wrap_init
|
||||
def wrapped(*args_flat):
|
||||
@ -718,6 +722,7 @@ def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
|
||||
avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
|
||||
sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
|
||||
sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
|
||||
assert out_tree is not None
|
||||
return sp_jaxpr, out_tree
|
||||
|
||||
def _while_sparse(spenv, *spvalues, cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts):
|
||||
@ -846,7 +851,7 @@ def _todense_sparse_rule(spenv, spvalue, *, tree):
|
||||
sparse_rules_bcoo[sparse.todense_p] = _todense_sparse_rule
|
||||
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
# ------------------------------------------------------------------------------
|
||||
# BCOO methods derived from sparsify
|
||||
# defined here to avoid circular imports
|
||||
|
||||
@ -908,7 +913,7 @@ _bcoo_methods = {
|
||||
for method, impl in _bcoo_methods.items():
|
||||
setattr(BCOO, method, impl)
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
# ------------------------------------------------------------------------------
|
||||
# BCSR methods derived from sparsify
|
||||
# defined here to avoid circular imports
|
||||
|
||||
|
@ -26,6 +26,7 @@ load("@org_tensorflow//tensorflow/core/platform:build_config_root.bzl", _tf_cuda
|
||||
cuda_library = _cuda_library
|
||||
rocm_library = _rocm_library
|
||||
pytype_library = native.py_library
|
||||
pytype_strict_library = native.py_library
|
||||
pytype_test = native.py_test
|
||||
pyx_library = _pyx_library
|
||||
pybind_extension = _pybind_extension
|
||||
|
Loading…
x
Reference in New Issue
Block a user