Prune some exports from jax.experimental.pjit.

jax.experimental.pjit is deprecated in its entirety (use "jit" instead), and experimental APIs have no stability promises.

PiperOrigin-RevId: 552903601
This commit is contained in:
Peter Hawkins 2023-08-01 13:26:43 -07:00 committed by jax authors
parent 2e042b6195
commit 0116d196a7
4 changed files with 9 additions and 16 deletions

View File

@ -503,6 +503,7 @@ pytype_strict_library(
pytype_strict_library(
name = "sharding_impls",
srcs = ["_src/sharding_impls.py"],
visibility = [":internal"] + jax_visibility("sharding_impls"),
deps = [
":mesh",
":op_shardings",

View File

@ -2262,7 +2262,7 @@ def get_gspmd_shardings_from_executable(
num_in_avals: int, num_out_avals: int
) -> tuple[Sequence[sharding_impls.XLACompatibleSharding],
Sequence[sharding_impls.XLACompatibleSharding]]:
from jax.experimental import pjit
from jax._src import pjit
# When the device assignment only has 1 device, SPMD partitioner will not run.
# Hence the op shardings will not be set on the `hlo_module`. In that case,
@ -2272,7 +2272,7 @@ def get_gspmd_shardings_from_executable(
ss = sharding_impls.SingleDeviceSharding(device_assignment[0])
return [ss] * num_in_avals, [ss] * num_out_avals
in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
in_op_shardings, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable)
in_shardings_xla = [sharding_impls.GSPMDSharding(device_assignment, i)
for i in in_op_shardings]
@ -2295,9 +2295,9 @@ def _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh: Mesh
) -> tuple[Sequence[sharding_impls.NamedSharding],
Sequence[sharding_impls.NamedSharding]]:
from jax.experimental import pjit
from jax._src import pjit
in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
in_pspec, out_pspec = pjit.get_pspec_from_executable(xla_executable, mesh)
return ([sharding_impls.NamedSharding(mesh, i) for i in in_pspec],
[sharding_impls.NamedSharding(mesh, o) for o in out_pspec])

View File

@ -1973,7 +1973,7 @@ def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[Parti
return [get_single_pspec(p) for p in ppspec]
def _get_op_sharding_from_executable(
def get_op_sharding_from_executable(
executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
in_op_shardings: list[xc.OpSharding] = []
parameter_shardings_from_xla = executable.get_parameter_shardings()
@ -1989,7 +1989,7 @@ def _get_op_sharding_from_executable(
def _get_ppspec_from_executable(executable, mesh) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
input_op_shardings, output_op_sharding = _get_op_sharding_from_executable(
input_op_shardings, output_op_sharding = get_op_sharding_from_executable(
executable
)
in_ppspec: list[ParsedPartitionSpec] = []
@ -2002,7 +2002,7 @@ def _get_ppspec_from_executable(executable, mesh) -> tuple[Sequence[ParsedPartit
return in_ppspec, out_ppspec
def _get_pspec_from_executable(
def get_pspec_from_executable(
executable, mesh: pxla.Mesh
) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]:
in_ppspec, out_ppspec = _get_ppspec_from_executable(executable, mesh)

View File

@ -15,7 +15,6 @@
# flake8: noqa
from jax._src.pjit import (
hashable_pytree as hashable_pytree,
pjit as pjit,
pjit_p as pjit_p,
with_sharding_constraint as with_sharding_constraint,
@ -23,13 +22,6 @@ from jax._src.pjit import (
from jax._src.sharding_impls import (
AUTO as AUTO,
UNSPECIFIED as _UNSPECIFIED,
ParsedPartitionSpec as ParsedPartitionSpec,
get_array_mapping as get_array_mapping,
prepare_axis_resources as _prepare_axis_resources,
parse_flatten_op_sharding as parse_flatten_op_sharding,
)
from jax._src.pjit import (_get_op_sharding_from_executable,
_get_pspec_from_executable, _pjit_lower_cached,
_pjit_lower, _pjit_jaxpr,
_process_in_axis_resources)
from jax._src.pjit import (_pjit_lower_cached, _pjit_lower)