Removed get_compute_capability from jax.experimental.pallas.gpu

Compute capability is available as the `compute_capability` attribute (a `str`) on GPU devices since jaxlib 0.4.26.
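
A minimal sketch of the replacement usage, assuming a jaxlib >= 0.4.26 GPU build and at least one visible CUDA device:

```python
import jax

# Read the capability straight off a device object.
gpu, *_ = jax.local_devices(backend="gpu")
print(gpu.compute_capability)  # e.g. "8.0" on an A100
```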
Sergei Lebedev 2024-05-08 20:29:18 +01:00
parent a145109ac2
commit 575ba942e0
10 changed files with 29 additions and 100 deletions

View File

@@ -11,6 +11,9 @@ Remember to align the itemized text with the first line of an item within a list
* Deprecations & removals
* The ``kind`` argument to {func}`jax.numpy.sort` and {func}`jax.numpy.argsort`
is now removed. Use `stable=True` or `stable=False` instead.
* Removed ``get_compute_capability`` from the ``jax.experimental.pallas.gpu``
module. Use the ``compute_capability`` attribute of a GPU device, returned
by {func}`jax.devices` or {func}`jax.local_devices`, instead.
## jaxlib 0.4.28
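
To make the changelog entry above concrete, here is a hedged before/after sketch of the migration, assuming a local GPU is present; the device index 0 and the `float` comparison are illustrative choices, not part of this commit:

```python
import jax

# Before (removed in this commit): an int such as 80.
#   from jax.experimental.pallas import gpu as plgpu
#   cc = plgpu.get_compute_capability(0)
# After: a string such as "8.0", read from the device object.
cc = jax.devices("gpu")[0].compute_capability
if float(cc) >= 8.0:
  print("Running on sm80 or newer")
```

Converting the string to a number avoids surprises that lexicographic comparison of capability strings (e.g. "9.0" vs. "10.0") would introduce.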

View File

@@ -12,24 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains Triton-specific pallas modules."""
"""Triton-specific Pallas APIs."""
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.pallas.triton.primitives import approx_tanh
from jax._src.pallas.triton.primitives import elementwise_inline_asm
try:
get_compute_capability = triton_kernel_call_lib.get_compute_capability
except AttributeError:
def get_compute_capability(device) -> int:
del device # Unused.
raise RuntimeError(
"get_compute_capability is not available. Try installing jaxlib with"
" GPU support following instructions in"
" https://jax.readthedocs.io/en/latest/installation.html."
)
del triton_kernel_call_lib

View File

@@ -25,7 +25,6 @@ from typing import Any
import jax
from jax import core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.pallas_call import pallas_call_p
@@ -58,13 +57,10 @@ def _pallas_call_ttir_lowering(
num_warps: int,
num_stages: int,
):
# TODO(sharadmv): handle multiple devices, right now we assume device 0
# which is fine when we have multiple of the same GPU but this won't work in
# general.
device = 0
compute_capability = triton_kernel_call_lib.get_compute_capability(device)
# TODO(sharadmv): Handle multiple devices with different capabilities.
d, *_ = jax.local_devices(backend="gpu")
cuda_options = dict(
compute_capability=compute_capability,
compute_capability=d.compute_capability,
num_warps=num_warps,
num_stages=num_stages,
debug=debug,

View File

@@ -419,6 +419,12 @@ def is_device_tpu(version: int | None = None, variant: str = "") -> bool:
return "v5 lite" in device_kind
return expected_version in device_kind
def is_device_gpu_at_least(capability: str) -> bool:
if device_under_test() != "gpu":
return False
d, *_ = jax.local_devices(backend="gpu")
return d.compute_capability >= capability
def _get_device_tags():
"""returns a set of tags defined for the device under test"""
if is_device_rocm():

View File

@@ -12,11 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains Triton specific Pallas functions."""
from jax._src.pallas import triton
"""Triton-specific Pallas APIs."""
from jax._src.pallas.triton import approx_tanh
from jax._src.pallas.triton import elementwise_inline_asm
get_compute_capability = triton.get_compute_capability
del triton

View File

@@ -26,7 +26,6 @@ from jax._src import test_util as jtu
from jax._src.internal_test_util import export_back_compat_test_util as bctu
from jax._src.internal_test_util.export_back_compat_test_data.pallas import cuda_add_one
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu
config.parse_flags_with_absl()
@@ -40,7 +39,7 @@ class CompatTest(bctu.CompatTestBase):
if not jtu.test_device_matches(["gpu"]):
self.skipTest("Only works on GPU")
if (jtu.test_device_matches(["cuda"]) and
plgpu.get_compute_capability(0) < 80):
not jtu.is_device_gpu_at_least("8.0")):
self.skipTest("Only works on GPUs with capability >= sm80")
super().setUp()

View File

@@ -13,7 +13,6 @@
# limitations under the License.
import os
import unittest
from absl.testing import absltest
from absl.testing import parameterized
@@ -28,10 +27,6 @@ import numpy as np
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
try:
from jax.experimental.pallas import gpu as plgpu
except ImportError:
pass
# pylint: disable=no-value-for-parameter
@@ -39,21 +34,12 @@ config.update("jax_traceback_filtering", "off")
config.parse_flags_with_absl()
class PallasTest(jtu.JaxTestCase):
def check_gpu_capability_at_least(self, capability, device: int = 0):
return plgpu.get_compute_capability(device) >= capability
class DecodeAttentionTest(jtu.JaxTestCase):
def setUp(self):
if not jtu.test_device_matches(["gpu"]):
self.skipTest("Only works on GPU")
try:
import triton # noqa: F401
except ImportError:
self.skipTest("Triton is not installed. Skipping PallasTest.")
super().setUp()
class DecodeAttentionTest(PallasTest):
if not jtu.is_device_gpu_at_least("8.0"):
self.skipTest("Fused attention only works on GPUs with capability >= sm80")
@parameterized.named_parameters(*[
(
@@ -86,10 +72,6 @@ class DecodeAttentionTest(PallasTest):
kwargs,
):
del kwargs
if not self.check_gpu_capability_at_least(80):
raise unittest.SkipTest(
"Fused attention only works on GPUs with capability >= sm80"
)
k1, k2, k3 = random.split(random.key(0), 3)
q = random.normal(k1, (batch_size, num_heads, head_dim), dtype=jnp.float16)
@@ -134,10 +116,6 @@ class DecodeAttentionTest(PallasTest):
kwargs,
):
del kwargs
if not self.check_gpu_capability_at_least(80):
raise unittest.SkipTest(
"Fused attention only works on GPUs with capability >= sm80"
)
k1, k2, k3 = random.split(random.key(0), 3)
q = random.normal(

View File

@@ -25,10 +25,7 @@ import jax.numpy as jnp
from jax import lax
from jax._src import test_util as jtu
from jax.experimental import pallas as pl
try:
from jax.experimental.pallas import gpu as plgpu
except ImportError:
plgpu = None
jax.config.parse_flags_with_absl()
@@ -40,19 +37,12 @@ class OpsTest(jtu.JaxTestCase):
super().setUp()
if jax.config.x64_enabled:
self.skipTest("Only works in 32-bit")
if jtu.device_under_test() == "cpu" and not self.INTERPRET:
self.skipTest("Only interpreter mode supported on CPU")
if (jtu.test_device_matches(["cuda"]) and
not self.check_gpu_capability_at_least(80)):
self.skipTest("Only works on GPUs with capability >= sm80")
def check_gpu_capability_at_least(self, capability,
device: int = 0):
if plgpu is None:
return False
if self.INTERPRET:
return True
return plgpu.get_compute_capability(device) >= capability
if not self.INTERPRET:
if jtu.device_under_test() == "cpu":
self.skipTest("Only interpreter mode supported on CPU")
if (jtu.test_device_matches(["cuda"]) and
not jtu.is_device_gpu_at_least("8.0")):
self.skipTest("Only works on GPUs with capability >= sm80")
@classmethod
def pallas_call(cls, *args, **kwargs):

View File

@@ -134,7 +134,7 @@ class PallasTest(parameterized.TestCase):
if not jtu.test_device_matches(["gpu"]):
self.skipTest("Only works on GPU")
if (jtu.test_device_matches(["cuda"]) and
not self.check_gpu_capability_at_least(80)):
not jtu.is_device_gpu_at_least("8.0")):
self.skipTest("Only works on GPUs with capability >= sm80")
super().setUp()
@@ -143,12 +143,6 @@
def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
def check_gpu_capability_at_least(self, capability,
device: int = 0):
if self.INTERPRET:
return True
return plgpu.get_compute_capability(device) >= capability
class PallasCallTest(PallasTest):
@@ -336,11 +330,6 @@ class PallasCallTest(PallasTest):
if block_size_m <= m and block_size_n <= n and block_size_k <= k
])
def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm):
if not self.INTERPRET and (
plgpu.get_compute_capability(0) <= 75
and (bm >= 128 or bn > 128 or bk > 32)
):
raise unittest.SkipTest("Block sizes too big for sm70.")
k1, k2 = random.split(random.key(0))
x = random.normal(k1, (m, k), dtype=dtype)
y = random.normal(k2, (k, n), dtype=dtype)
@@ -362,12 +351,6 @@ class PallasCallTest(PallasTest):
if block_size_m <= m and block_size_n <= n and block_size_k <= k
])
def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk):
if not self.INTERPRET and (
plgpu.get_compute_capability(0) <= 75
and (bm >= 128 or bn > 128 or bk > 32)
):
raise unittest.SkipTest("Block sizes too big for sm70.")
k1, k2 = random.split(random.key(0))
x = random.normal(k1, (m, k), dtype=dtype)
y = random.normal(k2, (k, n), dtype=dtype)
@@ -1666,7 +1649,7 @@ class PallasOpsTest(PallasTest):
def test_approx_tanh(self, dtype):
if self.INTERPRET:
self.skipTest("approx_tanh is not supported in interpreter mode")
if dtype == "bfloat16" and not self.check_gpu_capability_at_least(90):
if dtype == "bfloat16" and not jtu.is_device_gpu_at_least("9.0"):
self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90")
@functools.partial(

View File

@@ -22,7 +22,6 @@ import jax
import jax.numpy as jnp
from jax import dtypes
from jax._src import test_util as jtu
from jax.experimental.pallas import gpu as plgpu
from jax.experimental.sparse import nm
jax.config.parse_flags_with_absl()
@@ -33,14 +32,10 @@ class SpmmTest(jtu.JaxTestCase):
if not jtu.test_device_matches(["gpu"]):
self.skipTest("Only works on GPU")
if (jtu.test_device_matches(["cuda"]) and
not self.check_gpu_capability_at_least(80)):
not jtu.is_device_gpu_at_least("8.0")):
self.skipTest("Only works on GPUs with capability >= sm80")
super().setUp()
def check_gpu_capability_at_least(self, capability,
device: int = 0):
return plgpu.get_compute_capability(device) >= capability
# ----- Test different input shapes
@parameterized.product(
tile_m=(32, 128),