mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Added jax.experimental.pallas.mosaic_gpu
I also deprecated `jax.experimental.pallas.gpu` in favor of `jax.experimental.pallas.triton` to avoid confusion with the Mosaic GPU backend. PiperOrigin-RevId: 683119193
This commit is contained in:
parent
6d2c8cf5de
commit
95631a7d92
@ -23,6 +23,10 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
|
||||
* Deprecations
|
||||
|
||||
* The {mod}`jax.experimental.pallas.gpu` submodule is deprecated to avoid
|
||||
ambiguite with {mod}`jax.experimental.pallas.mosaic_gpu`. To use the
|
||||
Triton backend import {mod}`jax.experimental.pallas.triton`.
|
||||
|
||||
* New functionality
|
||||
|
||||
* {func}`jax.experimental.pallas.pallas_call` now accepts `scratch_shapes`,
|
||||
|
27
jax/BUILD
27
jax/BUILD
@ -587,9 +587,11 @@ pytype_strict_library(
|
||||
],
|
||||
exclude = [
|
||||
"experimental/pallas/gpu.py",
|
||||
"experimental/pallas/tpu.py",
|
||||
"experimental/pallas/mosaic_gpu.py",
|
||||
"experimental/pallas/ops/gpu/**/*.py",
|
||||
"experimental/pallas/ops/tpu/**/*.py",
|
||||
"experimental/pallas/tpu.py",
|
||||
"experimental/pallas/triton.py",
|
||||
],
|
||||
),
|
||||
visibility = [
|
||||
@ -649,21 +651,38 @@ pytype_strict_library(
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
# TODO(slebedev): Rename to :pallas_triton and update all users. Reserve :pallas_gpu
|
||||
# for both GPU backends.
|
||||
pytype_strict_library(
|
||||
name = "pallas_gpu",
|
||||
srcs = ["experimental/pallas/gpu.py"],
|
||||
srcs = [
|
||||
"experimental/pallas/gpu.py",
|
||||
"experimental/pallas/triton.py",
|
||||
],
|
||||
visibility = [
|
||||
":pallas_gpu_users",
|
||||
],
|
||||
deps = [
|
||||
"//jax/_src/pallas/mosaic_gpu:core", # build_cleaner: keep
|
||||
"//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep
|
||||
":deprecations",
|
||||
"//jax/_src/pallas/triton:core",
|
||||
"//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep
|
||||
"//jax/_src/pallas/triton:primitives",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "pallas_mosaic_gpu",
|
||||
srcs = ["experimental/pallas/mosaic_gpu.py"],
|
||||
visibility = [
|
||||
":mosaic_gpu_users",
|
||||
],
|
||||
deps = [
|
||||
"//jax/_src/pallas/mosaic_gpu:core",
|
||||
"//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep
|
||||
"//jax/_src/pallas/mosaic_gpu:primitives",
|
||||
],
|
||||
)
|
||||
|
||||
# This target only supports sm_90 GPUs.
|
||||
py_library(
|
||||
name = "mosaic_gpu",
|
||||
|
@ -133,3 +133,4 @@ register('jax-numpy-linalg-matrix_rank-tol')
|
||||
register('jax-numpy-linalg-pinv-rcond')
|
||||
register('jax-numpy-quantile-interpolation')
|
||||
register('jax-numpy-trimzeros-not-1d-array')
|
||||
register('pallas-gpu-triton')
|
||||
|
@ -11,22 +11,3 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO(slebedev): Move these imports to ``jax.experimental.pallas``.
|
||||
|
||||
from jax._src.pallas.mosaic_gpu.core import Barrier
|
||||
from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
|
||||
from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
|
||||
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
|
||||
from jax._src.pallas.mosaic_gpu.core import TilingTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wait_barrier
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait
|
||||
|
||||
GMEM = GPUMemorySpace.GMEM
|
||||
SMEM = GPUMemorySpace.SMEM
|
||||
|
@ -12,9 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Triton-specific Pallas APIs."""
|
||||
from jax._src import deprecations
|
||||
|
||||
from jax._src.pallas.triton.core import TritonCompilerParams
|
||||
from jax._src.pallas.triton.primitives import approx_tanh
|
||||
from jax._src.pallas.triton.primitives import debug_barrier
|
||||
from jax._src.pallas.triton.primitives import elementwise_inline_asm
|
||||
deprecations.warn(
|
||||
"pallas-gpu-triton",
|
||||
"The ``jax.experimental.pallas.gpu`` submodule is deprecated. "
|
||||
" Use ``jax.experimental.pallas.triton`` instead.",
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
from jax.experimental.pallas.triton import * # noqa: F403
|
||||
|
35
jax/experimental/pallas/mosaic_gpu.py
Normal file
35
jax/experimental/pallas/mosaic_gpu.py
Normal file
@ -0,0 +1,35 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Experimental GPU backend for Pallas targeting H100.
|
||||
|
||||
These APIs are highly unstable and can change weekly. Use at your own risk.
|
||||
"""
|
||||
|
||||
from jax._src.pallas.mosaic_gpu.core import Barrier
|
||||
from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
|
||||
from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
|
||||
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
|
||||
from jax._src.pallas.mosaic_gpu.core import TilingTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wait_barrier
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait
|
||||
|
||||
GMEM = GPUMemorySpace.GMEM
|
||||
SMEM = GPUMemorySpace.SMEM
|
@ -21,7 +21,7 @@ from typing import Any
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import gpu as plgpu
|
||||
from jax.experimental.pallas import triton as plgpu
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
@ -21,7 +21,7 @@ from typing import Any
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import gpu as plgpu
|
||||
from jax.experimental.pallas import triton as plgpu
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
|
@ -24,7 +24,7 @@ import jax.numpy as jnp
|
||||
from jax._src.lax.control_flow.for_loop import for_loop
|
||||
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import gpu as plgpu
|
||||
from jax.experimental.pallas import triton as plgpu
|
||||
|
||||
def layer_norm_forward_kernel(
|
||||
x_ref, weight_ref, bias_ref, # Input arrays
|
||||
|
@ -26,7 +26,7 @@ import jax.numpy as jnp
|
||||
from jax._src.lax.control_flow.for_loop import for_loop
|
||||
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import gpu as plgpu
|
||||
from jax.experimental.pallas import triton as plgpu
|
||||
|
||||
def rms_norm_forward_kernel(
|
||||
x_ref, weight_ref, bias_ref, # Input arrays
|
||||
|
@ -18,7 +18,7 @@ import functools
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import gpu as plgpu
|
||||
from jax.experimental.pallas import triton as plgpu
|
||||
|
||||
|
||||
def _vmappable_softmax_kernel(
|
||||
|
20
jax/experimental/pallas/triton.py
Normal file
20
jax/experimental/pallas/triton.py
Normal file
@ -0,0 +1,20 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Triton-specific Pallas APIs."""
|
||||
|
||||
from jax._src.pallas.triton.core import TritonCompilerParams
|
||||
from jax._src.pallas.triton.primitives import approx_tanh
|
||||
from jax._src.pallas.triton.primitives import debug_barrier
|
||||
from jax._src.pallas.triton.primitives import elementwise_inline_asm
|
@ -165,7 +165,7 @@ jax_multiplatform_test(
|
||||
},
|
||||
deps = [
|
||||
"//jax:pallas",
|
||||
"//jax:pallas_gpu", # build_cleaner: keep
|
||||
"//jax:pallas_mosaic_gpu", # build_cleaner: keep
|
||||
"//jax/_src/pallas/mosaic_gpu",
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
)
|
||||
|
@ -21,8 +21,8 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
import jax._src.pallas.mosaic_gpu as plgpu
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import mosaic_gpu as plgpu
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
@ -39,7 +39,7 @@ from jax.interpreters import partial_eval as pe
|
||||
from jax.experimental import pallas as pl
|
||||
|
||||
if sys.platform != "win32":
|
||||
from jax.experimental.pallas import gpu as plgpu
|
||||
from jax.experimental.pallas import triton as plgpu
|
||||
from jax.experimental.pallas import tpu as pltpu
|
||||
else:
|
||||
plgpu = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user