Removed the double re-exporting of Pallas GPU/TPU APIs

jax.experimental.pallas.{gpu,tpu} now import directly from the relevant
jax._src.pallas.{triton,mosaic} submodules.

PiperOrigin-RevId: 641875127
This commit is contained in:
Sergei Lebedev 2024-06-10 05:58:35 -07:00 committed by jax authors
parent 3b4039c850
commit 5e7ad600e2
8 changed files with 58 additions and 141 deletions

View File

@ -623,9 +623,14 @@ pytype_strict_library(
":pallas_tpu_users",
],
deps = [
":pallas", # buildcleaner: keep
":pallas", # build_cleaner: keep
":tpu_custom_call",
"//jax/_src/pallas/mosaic",
"//jax/_src/pallas/mosaic:core",
"//jax/_src/pallas/mosaic:kernel_regeneration_util",
"//jax/_src/pallas/mosaic:lowering",
"//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/mosaic:pipeline",
"//jax/_src/pallas/mosaic:primitives",
],
)
@ -663,8 +668,9 @@ pytype_strict_library(
],
deps = [
":pallas",
"//jax/_src/pallas/mosaic_gpu",
"//jax/_src/pallas/triton",
"//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/triton:primitives",
],
)

View File

@ -27,13 +27,13 @@ package(
py_library(
name = "pallas",
srcs = glob(
include = ["**/*.py"],
exclude = [
"triton/*.py",
"mosaic/*.py",
],
),
srcs = [
"__init__.py",
"core.py",
"pallas_call.py",
"primitives.py",
"utils.py",
],
deps = [
"//jax",
"//jax:ad_util",
@ -46,21 +46,3 @@ py_library(
"//jax/_src/lib",
] + py_deps("numpy"),
)
py_library(
name = "gpu",
visibility = [],
deps = [
":pallas",
"//jax/_src/pallas/triton",
],
)
py_library(
name = "tpu",
visibility = [],
deps = [
":pallas",
"//jax/_src/pallas/mosaic",
],
)

View File

@ -15,11 +15,7 @@
# Package for Mosaic-specific Pallas extensions
load("@rules_python//python:defs.bzl", "py_library")
load(
"//jaxlib:jax.bzl",
"py_deps",
"py_library_providing_imports_info",
)
load("//jaxlib:jax.bzl", "py_deps")
package(
default_applicable_licenses = [],
@ -28,20 +24,6 @@ package(
],
)
py_library_providing_imports_info(
name = "mosaic",
srcs = ["__init__.py"],
lib_rule = py_library,
deps = [
":core",
":kernel_regeneration_util",
":lowering",
":pallas_call_registration",
":pipeline",
":primitives",
],
)
py_library(
name = "core",
srcs = ["core.py"],

View File

@ -11,42 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for Mosaic lowering of Pallas call."""
from jax._src.pallas.mosaic import core
from jax._src.pallas.mosaic.core import dma_semaphore
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
from jax._src.pallas.mosaic.core import semaphore
from jax._src.pallas.mosaic.core import SemaphoreType
from jax._src.pallas.mosaic.core import TPUMemorySpace
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
from jax._src.pallas.mosaic.lowering import LoweringException
from jax._src.pallas.mosaic.pipeline import BufferedRef
from jax._src.pallas.mosaic.pipeline import emit_pipeline
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
from jax._src.pallas.mosaic.primitives import async_copy
from jax._src.pallas.mosaic.primitives import async_remote_copy
from jax._src.pallas.mosaic.primitives import bitcast
from jax._src.pallas.mosaic.primitives import delay
from jax._src.pallas.mosaic.primitives import device_id
from jax._src.pallas.mosaic.primitives import DeviceIdType
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
from jax._src.pallas.mosaic.primitives import make_async_copy
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
from jax._src.pallas.mosaic.primitives import repeat
from jax._src.pallas.mosaic.primitives import roll
from jax._src.pallas.mosaic.primitives import run_scoped
from jax._src.pallas.mosaic.primitives import semaphore_read
from jax._src.pallas.mosaic.primitives import semaphore_signal
from jax._src.pallas.mosaic.primitives import semaphore_wait
from jax._src.pallas.mosaic.primitives import prng_seed
from jax._src.pallas.mosaic.primitives import prng_random_bits
ANY = TPUMemorySpace.ANY
CMEM = TPUMemorySpace.CMEM
SMEM = TPUMemorySpace.SMEM
VMEM = TPUMemorySpace.VMEM

View File

@ -17,7 +17,6 @@
load(
"//jaxlib:jax.bzl",
"py_deps",
"py_library_providing_imports_info",
"pytype_strict_library",
)
@ -28,18 +27,6 @@ package(
],
)
py_library_providing_imports_info(
name = "triton",
srcs = ["__init__.py"],
lib_rule = pytype_strict_library,
deps = [
":lowering",
":pallas_call_registration",
":primitives",
"//jax/_src/lib",
],
)
pytype_strict_library(
name = "primitives",
srcs = ["primitives.py"],

View File

@ -11,8 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Triton-specific Pallas APIs."""
from jax._src.pallas.triton.primitives import approx_tanh
from jax._src.pallas.triton.primitives import elementwise_inline_asm

View File

@ -14,5 +14,5 @@
"""Triton-specific Pallas APIs."""
from jax._src.pallas.triton import approx_tanh
from jax._src.pallas.triton import elementwise_inline_asm
from jax._src.pallas.triton.primitives import approx_tanh
from jax._src.pallas.triton.primitives import elementwise_inline_asm

View File

@ -12,38 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains Mosaic specific Pallas functions."""
from jax._src.pallas.mosaic import ANY
from jax._src.pallas.mosaic import CMEM
from jax._src.pallas.mosaic import PrefetchScalarGridSpec
from jax._src.pallas.mosaic import SMEM
from jax._src.pallas.mosaic import SemaphoreType
from jax._src.pallas.mosaic import TPUMemorySpace
from jax._src.pallas.mosaic import VMEM
from jax._src.pallas.mosaic import DeviceIdType
from jax._src.pallas.mosaic import async_copy
from jax._src.pallas.mosaic import async_remote_copy
from jax._src.pallas.mosaic import bitcast
from jax._src.pallas.mosaic import dma_semaphore
from jax._src.pallas.mosaic import delay
from jax._src.pallas.mosaic import device_id
from jax._src.pallas.mosaic import emit_pipeline_with_allocations
from jax._src.pallas.mosaic import emit_pipeline
from jax._src.pallas.mosaic import get_pipeline_schedule
from jax._src.pallas.mosaic import make_pipeline_allocations
from jax._src.pallas.mosaic import BufferedRef
from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata
from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata
from jax._src.pallas.mosaic import get_barrier_semaphore
from jax._src.pallas.mosaic import make_async_copy
from jax._src.pallas.mosaic import make_async_remote_copy
from jax._src.pallas.mosaic import repeat
from jax._src.pallas.mosaic import roll
from jax._src.pallas.mosaic import run_scoped
from jax._src.pallas.mosaic import semaphore
from jax._src.pallas.mosaic import semaphore_read
from jax._src.pallas.mosaic import semaphore_signal
from jax._src.pallas.mosaic import semaphore_wait
"""Mosaic-specific Pallas APIs."""
from jax._src.pallas.mosaic import core
from jax._src.pallas.mosaic.core import dma_semaphore
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
from jax._src.pallas.mosaic.core import semaphore
from jax._src.pallas.mosaic.core import SemaphoreType
from jax._src.pallas.mosaic.core import TPUMemorySpace
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
from jax._src.pallas.mosaic.lowering import LoweringException
from jax._src.pallas.mosaic.pipeline import BufferedRef
from jax._src.pallas.mosaic.pipeline import emit_pipeline
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
from jax._src.pallas.mosaic.primitives import async_copy
from jax._src.pallas.mosaic.primitives import async_remote_copy
from jax._src.pallas.mosaic.primitives import bitcast
from jax._src.pallas.mosaic.primitives import delay
from jax._src.pallas.mosaic.primitives import device_id
from jax._src.pallas.mosaic.primitives import DeviceIdType
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
from jax._src.pallas.mosaic.primitives import make_async_copy
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
from jax._src.pallas.mosaic.primitives import repeat
from jax._src.pallas.mosaic.primitives import roll
from jax._src.pallas.mosaic.primitives import run_scoped
from jax._src.pallas.mosaic.primitives import semaphore_read
from jax._src.pallas.mosaic.primitives import semaphore_signal
from jax._src.pallas.mosaic.primitives import semaphore_wait
from jax._src.pallas.mosaic.primitives import prng_seed
from jax._src.pallas.mosaic.primitives import prng_random_bits
from jax._src.tpu_custom_call import CostEstimate
from jax._src.pallas.mosaic import prng_seed
from jax._src.pallas.mosaic import prng_random_bits
ANY = TPUMemorySpace.ANY
CMEM = TPUMemorySpace.CMEM
SMEM = TPUMemorySpace.SMEM
VMEM = TPUMemorySpace.VMEM