mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Prune some exports from jax.experimental.pjit.
jax.experimental.pjit is deprecated in its entirety (use "jit" instead), and experimental APIs have no stability promises. PiperOrigin-RevId: 552903601
This commit is contained in:
parent
2e042b6195
commit
0116d196a7
@ -503,6 +503,7 @@ pytype_strict_library(
|
||||
pytype_strict_library(
|
||||
name = "sharding_impls",
|
||||
srcs = ["_src/sharding_impls.py"],
|
||||
visibility = [":internal"] + jax_visibility("sharding_impls"),
|
||||
deps = [
|
||||
":mesh",
|
||||
":op_shardings",
|
||||
|
@ -2262,7 +2262,7 @@ def get_gspmd_shardings_from_executable(
|
||||
num_in_avals: int, num_out_avals: int
|
||||
) -> tuple[Sequence[sharding_impls.XLACompatibleSharding],
|
||||
Sequence[sharding_impls.XLACompatibleSharding]]:
|
||||
from jax.experimental import pjit
|
||||
from jax._src import pjit
|
||||
|
||||
# When the device assignment only has 1 device, SPMD partitioner will not run.
|
||||
# Hence the op shardings will not be set on the `hlo_module`. In that case,
|
||||
@ -2272,7 +2272,7 @@ def get_gspmd_shardings_from_executable(
|
||||
ss = sharding_impls.SingleDeviceSharding(device_assignment[0])
|
||||
return [ss] * num_in_avals, [ss] * num_out_avals
|
||||
|
||||
in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
|
||||
in_op_shardings, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable)
|
||||
|
||||
in_shardings_xla = [sharding_impls.GSPMDSharding(device_assignment, i)
|
||||
for i in in_op_shardings]
|
||||
@ -2295,9 +2295,9 @@ def _get_mesh_pspec_shardings_from_executable(
|
||||
xla_executable, mesh: Mesh
|
||||
) -> tuple[Sequence[sharding_impls.NamedSharding],
|
||||
Sequence[sharding_impls.NamedSharding]]:
|
||||
from jax.experimental import pjit
|
||||
from jax._src import pjit
|
||||
|
||||
in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
|
||||
in_pspec, out_pspec = pjit.get_pspec_from_executable(xla_executable, mesh)
|
||||
return ([sharding_impls.NamedSharding(mesh, i) for i in in_pspec],
|
||||
[sharding_impls.NamedSharding(mesh, o) for o in out_pspec])
|
||||
|
||||
|
@ -1973,7 +1973,7 @@ def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[Parti
|
||||
return [get_single_pspec(p) for p in ppspec]
|
||||
|
||||
|
||||
def _get_op_sharding_from_executable(
|
||||
def get_op_sharding_from_executable(
|
||||
executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
|
||||
in_op_shardings: list[xc.OpSharding] = []
|
||||
parameter_shardings_from_xla = executable.get_parameter_shardings()
|
||||
@ -1989,7 +1989,7 @@ def _get_op_sharding_from_executable(
|
||||
|
||||
|
||||
def _get_ppspec_from_executable(executable, mesh) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
|
||||
input_op_shardings, output_op_sharding = _get_op_sharding_from_executable(
|
||||
input_op_shardings, output_op_sharding = get_op_sharding_from_executable(
|
||||
executable
|
||||
)
|
||||
in_ppspec: list[ParsedPartitionSpec] = []
|
||||
@ -2002,7 +2002,7 @@ def _get_ppspec_from_executable(executable, mesh) -> tuple[Sequence[ParsedPartit
|
||||
return in_ppspec, out_ppspec
|
||||
|
||||
|
||||
def _get_pspec_from_executable(
|
||||
def get_pspec_from_executable(
|
||||
executable, mesh: pxla.Mesh
|
||||
) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]:
|
||||
in_ppspec, out_ppspec = _get_ppspec_from_executable(executable, mesh)
|
||||
|
@ -15,7 +15,6 @@
|
||||
# flake8: noqa
|
||||
|
||||
from jax._src.pjit import (
|
||||
hashable_pytree as hashable_pytree,
|
||||
pjit as pjit,
|
||||
pjit_p as pjit_p,
|
||||
with_sharding_constraint as with_sharding_constraint,
|
||||
@ -23,13 +22,6 @@ from jax._src.pjit import (
|
||||
from jax._src.sharding_impls import (
|
||||
AUTO as AUTO,
|
||||
UNSPECIFIED as _UNSPECIFIED,
|
||||
ParsedPartitionSpec as ParsedPartitionSpec,
|
||||
get_array_mapping as get_array_mapping,
|
||||
prepare_axis_resources as _prepare_axis_resources,
|
||||
parse_flatten_op_sharding as parse_flatten_op_sharding,
|
||||
)
|
||||
|
||||
from jax._src.pjit import (_get_op_sharding_from_executable,
|
||||
_get_pspec_from_executable, _pjit_lower_cached,
|
||||
_pjit_lower, _pjit_jaxpr,
|
||||
_process_in_axis_resources)
|
||||
from jax._src.pjit import (_pjit_lower_cached, _pjit_lower)
|
||||
|
Loading…
x
Reference in New Issue
Block a user