mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Removed the double re-exporting of Pallas GPU/TPU APIs
jax.experimental.pallas.{gpu,tpu} now import directly from the relevant jax._src.pallas.{triton,mosaic} submodules. PiperOrigin-RevId: 641875127
This commit is contained in:
parent
3b4039c850
commit
5e7ad600e2
14
jax/BUILD
14
jax/BUILD
@ -623,9 +623,14 @@ pytype_strict_library(
|
||||
":pallas_tpu_users",
|
||||
],
|
||||
deps = [
|
||||
":pallas", # buildcleaner: keep
|
||||
":pallas", # build_cleaner: keep
|
||||
":tpu_custom_call",
|
||||
"//jax/_src/pallas/mosaic",
|
||||
"//jax/_src/pallas/mosaic:core",
|
||||
"//jax/_src/pallas/mosaic:kernel_regeneration_util",
|
||||
"//jax/_src/pallas/mosaic:lowering",
|
||||
"//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep
|
||||
"//jax/_src/pallas/mosaic:pipeline",
|
||||
"//jax/_src/pallas/mosaic:primitives",
|
||||
],
|
||||
)
|
||||
|
||||
@ -663,8 +668,9 @@ pytype_strict_library(
|
||||
],
|
||||
deps = [
|
||||
":pallas",
|
||||
"//jax/_src/pallas/mosaic_gpu",
|
||||
"//jax/_src/pallas/triton",
|
||||
"//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep
|
||||
"//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep
|
||||
"//jax/_src/pallas/triton:primitives",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -27,13 +27,13 @@ package(
|
||||
|
||||
py_library(
|
||||
name = "pallas",
|
||||
srcs = glob(
|
||||
include = ["**/*.py"],
|
||||
exclude = [
|
||||
"triton/*.py",
|
||||
"mosaic/*.py",
|
||||
],
|
||||
),
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"core.py",
|
||||
"pallas_call.py",
|
||||
"primitives.py",
|
||||
"utils.py",
|
||||
],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:ad_util",
|
||||
@ -46,21 +46,3 @@ py_library(
|
||||
"//jax/_src/lib",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gpu",
|
||||
visibility = [],
|
||||
deps = [
|
||||
":pallas",
|
||||
"//jax/_src/pallas/triton",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tpu",
|
||||
visibility = [],
|
||||
deps = [
|
||||
":pallas",
|
||||
"//jax/_src/pallas/mosaic",
|
||||
],
|
||||
)
|
||||
|
@ -15,11 +15,7 @@
|
||||
# Package for Mosaic-specific Pallas extensions
|
||||
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"py_deps",
|
||||
"py_library_providing_imports_info",
|
||||
)
|
||||
load("//jaxlib:jax.bzl", "py_deps")
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
@ -28,20 +24,6 @@ package(
|
||||
],
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
name = "mosaic",
|
||||
srcs = ["__init__.py"],
|
||||
lib_rule = py_library,
|
||||
deps = [
|
||||
":core",
|
||||
":kernel_regeneration_util",
|
||||
":lowering",
|
||||
":pallas_call_registration",
|
||||
":pipeline",
|
||||
":primitives",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "core",
|
||||
srcs = ["core.py"],
|
||||
|
@ -11,42 +11,3 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Module for Mosaic lowering of Pallas call."""
|
||||
|
||||
from jax._src.pallas.mosaic import core
|
||||
from jax._src.pallas.mosaic.core import dma_semaphore
|
||||
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
|
||||
from jax._src.pallas.mosaic.core import semaphore
|
||||
from jax._src.pallas.mosaic.core import SemaphoreType
|
||||
from jax._src.pallas.mosaic.core import TPUMemorySpace
|
||||
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
|
||||
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
|
||||
from jax._src.pallas.mosaic.lowering import LoweringException
|
||||
from jax._src.pallas.mosaic.pipeline import BufferedRef
|
||||
from jax._src.pallas.mosaic.pipeline import emit_pipeline
|
||||
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
|
||||
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
|
||||
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
|
||||
from jax._src.pallas.mosaic.primitives import async_copy
|
||||
from jax._src.pallas.mosaic.primitives import async_remote_copy
|
||||
from jax._src.pallas.mosaic.primitives import bitcast
|
||||
from jax._src.pallas.mosaic.primitives import delay
|
||||
from jax._src.pallas.mosaic.primitives import device_id
|
||||
from jax._src.pallas.mosaic.primitives import DeviceIdType
|
||||
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
|
||||
from jax._src.pallas.mosaic.primitives import make_async_copy
|
||||
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
|
||||
from jax._src.pallas.mosaic.primitives import repeat
|
||||
from jax._src.pallas.mosaic.primitives import roll
|
||||
from jax._src.pallas.mosaic.primitives import run_scoped
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_read
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_signal
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_wait
|
||||
from jax._src.pallas.mosaic.primitives import prng_seed
|
||||
from jax._src.pallas.mosaic.primitives import prng_random_bits
|
||||
|
||||
ANY = TPUMemorySpace.ANY
|
||||
CMEM = TPUMemorySpace.CMEM
|
||||
SMEM = TPUMemorySpace.SMEM
|
||||
VMEM = TPUMemorySpace.VMEM
|
||||
|
@ -17,7 +17,6 @@
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"py_deps",
|
||||
"py_library_providing_imports_info",
|
||||
"pytype_strict_library",
|
||||
)
|
||||
|
||||
@ -28,18 +27,6 @@ package(
|
||||
],
|
||||
)
|
||||
|
||||
py_library_providing_imports_info(
|
||||
name = "triton",
|
||||
srcs = ["__init__.py"],
|
||||
lib_rule = pytype_strict_library,
|
||||
deps = [
|
||||
":lowering",
|
||||
":pallas_call_registration",
|
||||
":primitives",
|
||||
"//jax/_src/lib",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "primitives",
|
||||
srcs = ["primitives.py"],
|
||||
|
@ -11,8 +11,3 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Triton-specific Pallas APIs."""
|
||||
|
||||
from jax._src.pallas.triton.primitives import approx_tanh
|
||||
from jax._src.pallas.triton.primitives import elementwise_inline_asm
|
||||
|
@ -14,5 +14,5 @@
|
||||
|
||||
"""Triton-specific Pallas APIs."""
|
||||
|
||||
from jax._src.pallas.triton import approx_tanh
|
||||
from jax._src.pallas.triton import elementwise_inline_asm
|
||||
from jax._src.pallas.triton.primitives import approx_tanh
|
||||
from jax._src.pallas.triton.primitives import elementwise_inline_asm
|
||||
|
@ -12,38 +12,42 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Contains Mosaic specific Pallas functions."""
|
||||
from jax._src.pallas.mosaic import ANY
|
||||
from jax._src.pallas.mosaic import CMEM
|
||||
from jax._src.pallas.mosaic import PrefetchScalarGridSpec
|
||||
from jax._src.pallas.mosaic import SMEM
|
||||
from jax._src.pallas.mosaic import SemaphoreType
|
||||
from jax._src.pallas.mosaic import TPUMemorySpace
|
||||
from jax._src.pallas.mosaic import VMEM
|
||||
from jax._src.pallas.mosaic import DeviceIdType
|
||||
from jax._src.pallas.mosaic import async_copy
|
||||
from jax._src.pallas.mosaic import async_remote_copy
|
||||
from jax._src.pallas.mosaic import bitcast
|
||||
from jax._src.pallas.mosaic import dma_semaphore
|
||||
from jax._src.pallas.mosaic import delay
|
||||
from jax._src.pallas.mosaic import device_id
|
||||
from jax._src.pallas.mosaic import emit_pipeline_with_allocations
|
||||
from jax._src.pallas.mosaic import emit_pipeline
|
||||
from jax._src.pallas.mosaic import get_pipeline_schedule
|
||||
from jax._src.pallas.mosaic import make_pipeline_allocations
|
||||
from jax._src.pallas.mosaic import BufferedRef
|
||||
from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata
|
||||
from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata
|
||||
from jax._src.pallas.mosaic import get_barrier_semaphore
|
||||
from jax._src.pallas.mosaic import make_async_copy
|
||||
from jax._src.pallas.mosaic import make_async_remote_copy
|
||||
from jax._src.pallas.mosaic import repeat
|
||||
from jax._src.pallas.mosaic import roll
|
||||
from jax._src.pallas.mosaic import run_scoped
|
||||
from jax._src.pallas.mosaic import semaphore
|
||||
from jax._src.pallas.mosaic import semaphore_read
|
||||
from jax._src.pallas.mosaic import semaphore_signal
|
||||
from jax._src.pallas.mosaic import semaphore_wait
|
||||
"""Mosaic-specific Pallas APIs."""
|
||||
|
||||
from jax._src.pallas.mosaic import core
|
||||
from jax._src.pallas.mosaic.core import dma_semaphore
|
||||
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
|
||||
from jax._src.pallas.mosaic.core import semaphore
|
||||
from jax._src.pallas.mosaic.core import SemaphoreType
|
||||
from jax._src.pallas.mosaic.core import TPUMemorySpace
|
||||
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
|
||||
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
|
||||
from jax._src.pallas.mosaic.lowering import LoweringException
|
||||
from jax._src.pallas.mosaic.pipeline import BufferedRef
|
||||
from jax._src.pallas.mosaic.pipeline import emit_pipeline
|
||||
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
|
||||
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
|
||||
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
|
||||
from jax._src.pallas.mosaic.primitives import async_copy
|
||||
from jax._src.pallas.mosaic.primitives import async_remote_copy
|
||||
from jax._src.pallas.mosaic.primitives import bitcast
|
||||
from jax._src.pallas.mosaic.primitives import delay
|
||||
from jax._src.pallas.mosaic.primitives import device_id
|
||||
from jax._src.pallas.mosaic.primitives import DeviceIdType
|
||||
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
|
||||
from jax._src.pallas.mosaic.primitives import make_async_copy
|
||||
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
|
||||
from jax._src.pallas.mosaic.primitives import repeat
|
||||
from jax._src.pallas.mosaic.primitives import roll
|
||||
from jax._src.pallas.mosaic.primitives import run_scoped
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_read
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_signal
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_wait
|
||||
from jax._src.pallas.mosaic.primitives import prng_seed
|
||||
from jax._src.pallas.mosaic.primitives import prng_random_bits
|
||||
from jax._src.tpu_custom_call import CostEstimate
|
||||
from jax._src.pallas.mosaic import prng_seed
|
||||
from jax._src.pallas.mosaic import prng_random_bits
|
||||
|
||||
ANY = TPUMemorySpace.ANY
|
||||
CMEM = TPUMemorySpace.CMEM
|
||||
SMEM = TPUMemorySpace.SMEM
|
||||
VMEM = TPUMemorySpace.VMEM
|
||||
|
Loading…
x
Reference in New Issue
Block a user