From 0116d196a77fc033e2b453a56af29e1253cac6c2 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 1 Aug 2023 13:26:43 -0700 Subject: [PATCH] Prune some exports from jax.experimental.pjit. jax.experimental.pjit is deprecated in its entirety (use "jit" instead), and experimental APIs have no stability promises. PiperOrigin-RevId: 552903601 --- jax/BUILD | 1 + jax/_src/interpreters/pxla.py | 8 ++++---- jax/_src/pjit.py | 6 +++--- jax/experimental/pjit.py | 10 +--------- 4 files changed, 9 insertions(+), 16 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index a41c17c7a..0a8f2b015 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -503,6 +503,7 @@ pytype_strict_library( pytype_strict_library( name = "sharding_impls", srcs = ["_src/sharding_impls.py"], + visibility = [":internal"] + jax_visibility("sharding_impls"), deps = [ ":mesh", ":op_shardings", diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index b31cffd61..61519b16d 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2262,7 +2262,7 @@ def get_gspmd_shardings_from_executable( num_in_avals: int, num_out_avals: int ) -> tuple[Sequence[sharding_impls.XLACompatibleSharding], Sequence[sharding_impls.XLACompatibleSharding]]: - from jax.experimental import pjit + from jax._src import pjit # When the device assignment only has 1 device, SPMD partitioner will not run. # Hence the op shardings will not be set on the `hlo_module`. In that case, @@ -2272,7 +2272,7 @@ def get_gspmd_shardings_from_executable( ss = sharding_impls.SingleDeviceSharding(device_assignment[0]) return [ss] * num_in_avals, [ss] * num_out_avals - in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable) + in_op_shardings, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable) in_shardings_xla = [sharding_impls.GSPMDSharding(device_assignment, i) for i in in_op_shardings] @@ -2295,9 +2295,9 @@ def _get_mesh_pspec_shardings_from_executable( xla_executable, mesh: Mesh ) -> tuple[Sequence[sharding_impls.NamedSharding], Sequence[sharding_impls.NamedSharding]]: - from jax.experimental import pjit + from jax._src import pjit - in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh) + in_pspec, out_pspec = pjit.get_pspec_from_executable(xla_executable, mesh) return ([sharding_impls.NamedSharding(mesh, i) for i in in_pspec], [sharding_impls.NamedSharding(mesh, o) for o in out_pspec]) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index e55bc7e3b..512e90745 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1973,7 +1973,7 @@ def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[Parti return [get_single_pspec(p) for p in ppspec] -def _get_op_sharding_from_executable( +def get_op_sharding_from_executable( executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]: in_op_shardings: list[xc.OpSharding] = [] parameter_shardings_from_xla = executable.get_parameter_shardings() @@ -1989,7 +1989,7 @@ def _get_op_sharding_from_executable( def _get_ppspec_from_executable(executable, mesh) -> tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]: - input_op_shardings, output_op_sharding = _get_op_sharding_from_executable( + input_op_shardings, output_op_sharding = get_op_sharding_from_executable( executable ) in_ppspec: list[ParsedPartitionSpec] = [] @@ -2002,7 +2002,7 @@ def _get_ppspec_from_executable(executable, mesh) -> tuple[Sequence[ParsedPartit return in_ppspec, out_ppspec -def _get_pspec_from_executable( +def get_pspec_from_executable( executable, mesh: pxla.Mesh ) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]: in_ppspec, out_ppspec = _get_ppspec_from_executable(executable, mesh) diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index 8b08342da..f83d22e03 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -15,7 +15,6 @@ # flake8: noqa from jax._src.pjit import ( - hashable_pytree as hashable_pytree, pjit as pjit, pjit_p as pjit_p, with_sharding_constraint as with_sharding_constraint, @@ -23,13 +22,6 @@ from jax._src.pjit import ( from jax._src.sharding_impls import ( AUTO as AUTO, UNSPECIFIED as _UNSPECIFIED, - ParsedPartitionSpec as ParsedPartitionSpec, - get_array_mapping as get_array_mapping, - prepare_axis_resources as _prepare_axis_resources, - parse_flatten_op_sharding as parse_flatten_op_sharding, ) -from jax._src.pjit import (_get_op_sharding_from_executable, - _get_pspec_from_executable, _pjit_lower_cached, - _pjit_lower, _pjit_jaxpr, - _process_in_axis_resources) +from jax._src.pjit import (_pjit_lower_cached, _pjit_lower)