Added jax.experimental.pallas.mosaic_gpu

I also deprecated `jax.experimental.pallas.gpu` in favor of
`jax.experimental.pallas.triton` to avoid confusion with the Mosaic GPU
backend.
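
For downstream code, a minimal migration sketch (hedged: `plgpu` and `plgpu_mgpu` are just illustrative aliases, and importing the Mosaic GPU module assumes a jaxlib build with Mosaic GPU support):

# Before: Triton-backed Pallas kernels imported the generic GPU module,
# which now re-exports the same APIs but emits a DeprecationWarning.
# from jax.experimental.pallas import gpu as plgpu

# After: name the Triton backend explicitly.
from jax.experimental.pallas import triton as plgpu

# The new Mosaic GPU backend lives in its own submodule.
from jax.experimental.pallas import mosaic_gpu as plgpu_mgpu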

PiperOrigin-RevId: 683119193
Sergei Lebedev 2024-10-07 04:04:16 -07:00 committed by jax authors
parent 6d2c8cf5de
commit 95631a7d92
15 changed files with 100 additions and 36 deletions

View File

@@ -23,6 +23,10 @@ Remember to align the itemized text with the first line of an item within a list
* Deprecations
* The {mod}`jax.experimental.pallas.gpu` submodule is deprecated to avoid
ambiguity with {mod}`jax.experimental.pallas.mosaic_gpu`. To use the
Triton backend, import {mod}`jax.experimental.pallas.triton`.
* New functionality
* {func}`jax.experimental.pallas.pallas_call` now accepts `scratch_shapes`,

View File

@@ -587,9 +587,11 @@ pytype_strict_library(
],
exclude = [
"experimental/pallas/gpu.py",
"experimental/pallas/tpu.py",
"experimental/pallas/mosaic_gpu.py",
"experimental/pallas/ops/gpu/**/*.py",
"experimental/pallas/ops/tpu/**/*.py",
"experimental/pallas/tpu.py",
"experimental/pallas/triton.py",
],
),
visibility = [
@@ -649,21 +651,38 @@ pytype_strict_library(
] + py_deps("numpy"),
)
# TODO(slebedev): Rename to :pallas_triton and update all users. Reserve :pallas_gpu
# for both GPU backends.
pytype_strict_library(
name = "pallas_gpu",
srcs = ["experimental/pallas/gpu.py"],
srcs = [
"experimental/pallas/gpu.py",
"experimental/pallas/triton.py",
],
visibility = [
":pallas_gpu_users",
],
deps = [
"//jax/_src/pallas/mosaic_gpu:core", # build_cleaner: keep
"//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep
":deprecations",
"//jax/_src/pallas/triton:core",
"//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/triton:primitives",
],
)
pytype_strict_library(
name = "pallas_mosaic_gpu",
srcs = ["experimental/pallas/mosaic_gpu.py"],
visibility = [
":mosaic_gpu_users",
],
deps = [
"//jax/_src/pallas/mosaic_gpu:core",
"//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/mosaic_gpu:primitives",
],
)
# This target only supports sm_90 GPUs.
py_library(
name = "mosaic_gpu",

View File

@@ -133,3 +133,4 @@ register('jax-numpy-linalg-matrix_rank-tol')
register('jax-numpy-linalg-pinv-rcond')
register('jax-numpy-quantile-interpolation')
register('jax-numpy-trimzeros-not-1d-array')
register('pallas-gpu-triton')

View File

@@ -11,22 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(slebedev): Move these imports to ``jax.experimental.pallas``.
from jax._src.pallas.mosaic_gpu.core import Barrier
from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
from jax._src.pallas.mosaic_gpu.core import TilingTransform
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import wait_barrier
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import wgmma
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait
GMEM = GPUMemorySpace.GMEM
SMEM = GPUMemorySpace.SMEM

View File

@@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Triton-specific Pallas APIs."""
from jax._src import deprecations
from jax._src.pallas.triton.core import TritonCompilerParams
from jax._src.pallas.triton.primitives import approx_tanh
from jax._src.pallas.triton.primitives import debug_barrier
from jax._src.pallas.triton.primitives import elementwise_inline_asm
deprecations.warn(
"pallas-gpu-triton",
"The ``jax.experimental.pallas.gpu`` submodule is deprecated."
" Use ``jax.experimental.pallas.triton`` instead.",
stacklevel=1,
)
from jax.experimental.pallas.triton import * # noqa: F403
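
A minimal sketch of what importing the deprecated module now does (assuming the `pallas-gpu-triton` deprecation registered above has not been promoted to an error, and noting the warning only fires on the first import in a process):

import warnings

with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  from jax.experimental.pallas import gpu as plgpu  # noqa: F401

print([str(w.message) for w in caught])  # expect the "submodule is deprecated" message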

View File

@@ -0,0 +1,35 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental GPU backend for Pallas targeting H100.
These APIs are highly unstable and can change weekly. Use at your own risk.
"""
from jax._src.pallas.mosaic_gpu.core import Barrier
from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec
from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams
from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace
from jax._src.pallas.mosaic_gpu.core import TilingTransform
from jax._src.pallas.mosaic_gpu.core import TransposeTransform
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import wait_barrier
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import wgmma
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait
GMEM = GPUMemorySpace.GMEM
SMEM = GPUMemorySpace.SMEM
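
For orientation only, a small hedged sketch of the exported memory spaces; it just builds a spec and does not launch a kernel, since the copy/wgmma primitives above are the explicitly unstable part (assumes a jaxlib build with Mosaic GPU support):

from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

# Keep an operand in GMEM so the kernel body can stage it into SMEM itself,
# e.g. via copy_gmem_to_smem + wait_barrier (re-exported above).
gmem_in_spec = pl.BlockSpec(memory_space=plgpu.GMEM)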

View File

@@ -21,7 +21,7 @@ from typing import Any
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu
from jax.experimental.pallas import triton as plgpu
import jax.numpy as jnp
import numpy as np

View File

@@ -21,7 +21,7 @@ from typing import Any
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu
from jax.experimental.pallas import triton as plgpu
import jax.numpy as jnp

View File

@@ -24,7 +24,7 @@ import jax.numpy as jnp
from jax._src.lax.control_flow.for_loop import for_loop
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu
from jax.experimental.pallas import triton as plgpu
def layer_norm_forward_kernel(
x_ref, weight_ref, bias_ref, # Input arrays

View File

@@ -26,7 +26,7 @@ import jax.numpy as jnp
from jax._src.lax.control_flow.for_loop import for_loop
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu
from jax.experimental.pallas import triton as plgpu
def rms_norm_forward_kernel(
x_ref, weight_ref, bias_ref, # Input arrays

View File

@@ -18,7 +18,7 @@ import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu
from jax.experimental.pallas import triton as plgpu
def _vmappable_softmax_kernel(

View File

@@ -0,0 +1,20 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Triton-specific Pallas APIs."""
from jax._src.pallas.triton.core import TritonCompilerParams
from jax._src.pallas.triton.primitives import approx_tanh
from jax._src.pallas.triton.primitives import debug_barrier
from jax._src.pallas.triton.primitives import elementwise_inline_asm
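
As a usage sketch (hedged: assumes a CUDA GPU where Pallas lowers through the Triton backend by default; `approx_tanh` is the re-exported primitive above):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu

def tanh_kernel(x_ref, o_ref):
  # approx_tanh uses the GPU's fast tanh approximation instead of the exact op.
  o_ref[...] = plgpu.approx_tanh(x_ref[...])

x = jnp.linspace(-3.0, 3.0, 128, dtype=jnp.float32)
y = pl.pallas_call(
    tanh_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x)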

View File

@@ -165,7 +165,7 @@ jax_multiplatform_test(
},
deps = [
"//jax:pallas",
"//jax:pallas_gpu", # build_cleaner: keep
"//jax:pallas_mosaic_gpu", # build_cleaner: keep
"//jax/_src/pallas/mosaic_gpu",
] + py_deps("absl/testing") + py_deps("numpy"),
)

View File

@@ -21,8 +21,8 @@ from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
import jax._src.pallas.mosaic_gpu as plgpu
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu
import jax.numpy as jnp
import numpy as np

View File

@@ -39,7 +39,7 @@ from jax.interpreters import partial_eval as pe
from jax.experimental import pallas as pl
if sys.platform != "win32":
from jax.experimental.pallas import gpu as plgpu
from jax.experimental.pallas import triton as plgpu
from jax.experimental.pallas import tpu as pltpu
else:
plgpu = None