mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Split core.py and several files in an SCC with it into a separate Bazel build target.
PiperOrigin-RevId: 520192610
This commit is contained in:
parent
8c4fed6410
commit
c2d6fcc0e6
26
jax/BUILD
26
jax/BUILD
@ -96,7 +96,6 @@ py_library_providing_imports_info(
|
||||
"_src/array.py",
|
||||
"_src/callback.py",
|
||||
"_src/checkify.py",
|
||||
"_src/core.py",
|
||||
"_src/custom_batching.py",
|
||||
"_src/custom_derivatives.py",
|
||||
"_src/custom_transpose.py",
|
||||
@ -104,12 +103,9 @@ py_library_providing_imports_info(
|
||||
"_src/device_array.py",
|
||||
"_src/dispatch.py",
|
||||
"_src/dlpack.py",
|
||||
"_src/dtypes.py",
|
||||
"_src/errors.py",
|
||||
"_src/flatten_util.py",
|
||||
"_src/__init__.py",
|
||||
"_src/lax_reference.py",
|
||||
"_src/linear_util.py",
|
||||
"_src/maps.py",
|
||||
"_src/pjit.py",
|
||||
"_src/prng.py",
|
||||
@ -172,6 +168,7 @@ py_library_providing_imports_info(
|
||||
":custom_api_util",
|
||||
":config",
|
||||
":deprecations",
|
||||
":core",
|
||||
":effects",
|
||||
":environment_info",
|
||||
":lazy_loader",
|
||||
@ -215,6 +212,27 @@ pytype_strict_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "core",
|
||||
srcs = [
|
||||
"_src/core.py",
|
||||
"_src/dtypes.py",
|
||||
"_src/errors.py",
|
||||
"_src/linear_util.py",
|
||||
],
|
||||
deps = [
|
||||
":config",
|
||||
":effects",
|
||||
":pretty_printer",
|
||||
":source_info_util",
|
||||
":traceback_util",
|
||||
":tree_util",
|
||||
":typing",
|
||||
":util",
|
||||
"//jax/_src/lib",
|
||||
] + py_deps("ml_dtypes") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "custom_api_util",
|
||||
srcs = ["_src/custom_api_util.py"],
|
||||
|
@ -2387,6 +2387,26 @@ def _flat_axes_specs(abstracted_axes, *args, **kwargs
|
||||
return broadcast_prefix(abstracted_axes, args, ax_leaf)
|
||||
|
||||
|
||||
# TODO(phawkins): for some reason mypy cannot determine these overloads are
|
||||
# non-overlapping. Pytype is happy with them.
|
||||
@overload
|
||||
def make_jaxpr(fun: Callable, # type: ignore
|
||||
static_argnums: Union[int, Iterable[int]] = (),
|
||||
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
|
||||
return_shape: Literal[False] = ...,
|
||||
abstracted_axes: Optional[Any] = None,
|
||||
) -> Callable[..., core.ClosedJaxpr]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def make_jaxpr(fun: Callable, # type: ignore
|
||||
static_argnums: Union[int, Iterable[int]] = (),
|
||||
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
|
||||
return_shape: Literal[True] = ...,
|
||||
abstracted_axes: Optional[Any] = None,
|
||||
) -> Callable[..., Tuple[core.ClosedJaxpr, Any]]:
|
||||
...
|
||||
|
||||
def make_jaxpr(fun: Callable,
|
||||
static_argnums: Union[int, Iterable[int]] = (),
|
||||
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
|
||||
@ -3062,7 +3082,7 @@ def clear_backends():
|
||||
jax.lib.xla_bridge._backends = {}
|
||||
dispatch.xla_primitive_callable.cache_clear()
|
||||
pjit._pjit_lower_cached.cache_clear()
|
||||
pjit._create_pjit_jaxpr.cache_clear()
|
||||
pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error
|
||||
pjit._cpp_pjit_cache.clear()
|
||||
xc._xla.PjitFunctionCache.clear_all()
|
||||
|
||||
|
@ -40,8 +40,9 @@ from jax._src import dtypes
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import effects
|
||||
from jax._src.config import FLAGS, config
|
||||
from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
|
||||
TracerIntegerConversionError, UnexpectedTracerError)
|
||||
from jax._src.errors import (
|
||||
ConcretizationTypeError, TracerArrayConversionError,
|
||||
TracerIntegerConversionError, UnexpectedTracerError)
|
||||
from jax._src import linear_util as lu
|
||||
|
||||
from jax._src import source_info_util
|
||||
|
@ -600,6 +600,7 @@ def _check_sharding(aval, s):
|
||||
pjit.pjit_check_aval_sharding(
|
||||
(s,), (aval,), "device_put args", allow_uneven_sharding=False)
|
||||
|
||||
assert isinstance(aval, core.ShapedArray), aval
|
||||
s.shard_shape(aval.shape) # should raise an Error if incompatible
|
||||
|
||||
|
||||
|
@ -443,7 +443,7 @@ for t in device_array.device_array_types:
|
||||
shard_arg_handlers[t] = shard_device_array
|
||||
|
||||
|
||||
def batched_device_put(aval: core.AbstractValue,
|
||||
def batched_device_put(aval: core.ShapedArray,
|
||||
sharding: jax.sharding.Sharding, xs: Sequence[Any],
|
||||
devices: Sequence[jax.Device], committed: bool = True):
|
||||
from jax._src import array
|
||||
@ -1794,7 +1794,7 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
|
||||
devices = xb.get_backend(backend).get_default_device_assignment(nrep)
|
||||
assert nrep == len(devices)
|
||||
|
||||
aval = xla.abstractify(val) # type: ShapedArray
|
||||
aval = xla.abstractify(val)
|
||||
if in_axis is not None:
|
||||
replicated_aval = aval.update(shape=(axis_size,) + aval.shape)
|
||||
else:
|
||||
|
@ -191,7 +191,7 @@ canonicalize_dtype_handlers.update(
|
||||
canonicalize_dtype_handlers[core.Token] = identity
|
||||
canonicalize_dtype_handlers[core.DArray] = identity
|
||||
|
||||
def abstractify(x) -> core.AbstractValue:
|
||||
def abstractify(x) -> Any:
|
||||
typ = type(x)
|
||||
aval_fn = pytype_aval_mappings.get(typ)
|
||||
if aval_fn: return aval_fn(x)
|
||||
|
@ -67,8 +67,8 @@ from functools import partial
|
||||
from typing import Any, Tuple, Callable, Optional, NamedTuple
|
||||
import weakref
|
||||
|
||||
from jax.tree_util import tree_map
|
||||
from jax.config import config
|
||||
from jax._src.tree_util import tree_map
|
||||
from jax._src.config import config
|
||||
from jax._src import core
|
||||
from jax._src import traceback_util
|
||||
from jax._src.util import curry
|
||||
|
Loading…
x
Reference in New Issue
Block a user