From 575ba942e045097cdb4e78a62cb56502832795cf Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Wed, 8 May 2024 20:29:18 +0100
Subject: [PATCH] Removed get_compute_capability from jax.experimental.pallas.gpu

Compute capability is available as a `str` attribute on a GPU device since
jaxlib 0.4.26.
---
 CHANGELOG.md                                  |  3 ++
 jax/_src/pallas/triton/__init__.py            | 19 +-------------
 .../pallas/triton/pallas_call_registration.py | 10 ++-----
 jax/_src/test_util.py                         |  6 ++++
 jax/experimental/pallas/gpu.py                |  8 ++----
 .../pallas/export_back_compat_pallas_test.py  |  3 +-
 tests/pallas/gpu_attention_test.py            | 28 ++-----------------
 tests/pallas/ops_test.py                      | 24 +++++-----------
 tests/pallas/pallas_test.py                   | 21 ++-------------
 tests/sparse_nm_test.py                       |  7 +----
 10 files changed, 29 insertions(+), 100 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 594f8913a..c23d01c2d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,9 @@ Remember to align the itemized text with the first line of an item within a list
 * Deprecations & removals
   * The ``kind`` argument to {func}`jax.numpy.sort` and {func}`jax.numpy.argsort`
     is now removed. Use `stable=True` or `stable=False` instead.
+  * Removed ``get_compute_capability`` from the ``jax.experimental.pallas.gpu``
+    module. Use the ``compute_capability`` attribute of a GPU device, returned
+    by {func}`jax.devices` or {func}`jax.local_devices`, instead.
 
 ## jaxlib 0.4.28
 
diff --git a/jax/_src/pallas/triton/__init__.py b/jax/_src/pallas/triton/__init__.py
index ecfb85536..adade4e8a 100644
--- a/jax/_src/pallas/triton/__init__.py
+++ b/jax/_src/pallas/triton/__init__.py
@@ -12,24 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Contains Triton-specific pallas modules."""
+"""Triton-specific Pallas APIs."""
 
-from jax._src.lib import gpu_triton as triton_kernel_call_lib
 from jax._src.pallas.triton.primitives import approx_tanh
 from jax._src.pallas.triton.primitives import elementwise_inline_asm
-
-
-try:
-  get_compute_capability = triton_kernel_call_lib.get_compute_capability
-except AttributeError:
-
-  def get_compute_capability(device) -> int:
-    del device  # Unused.
-    raise RuntimeError(
-        "get_compute_capability is not available. Try installing jaxlib with"
-        " GPU support following instructions in"
-        " https://jax.readthedocs.io/en/latest/installation.html."
-    )
-
-
-del triton_kernel_call_lib
diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py
index ec6ff3a77..f3584e6f5 100644
--- a/jax/_src/pallas/triton/pallas_call_registration.py
+++ b/jax/_src/pallas/triton/pallas_call_registration.py
@@ -25,7 +25,6 @@ from typing import Any
 import jax
 from jax import core as jax_core
 from jax._src.interpreters import mlir
-from jax._src.lib import gpu_triton as triton_kernel_call_lib
 from jax._src.lib.mlir import ir
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.pallas_call import pallas_call_p
@@ -58,13 +57,10 @@ def _pallas_call_ttir_lowering(
     num_warps: int,
     num_stages: int,
 ):
-  # TODO(sharadmv): handle multiple devices, right now we assume device 0
-  # which is fine when we have multiple of the same GPU but this won't work in
-  # general.
-  device = 0
-  compute_capability = triton_kernel_call_lib.get_compute_capability(device)
+  # TODO(sharadmv): Handle multiple devices with different capabilities.
+  d, *_ = jax.local_devices(backend="gpu")
   cuda_options = dict(
-      compute_capability=compute_capability,
+      compute_capability=d.compute_capability,
       num_warps=num_warps,
       num_stages=num_stages,
       debug=debug,
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 411347153..7fc4b36d5 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -419,6 +419,12 @@ def is_device_tpu(version: int | None = None, variant: str = "") -> bool:
     return "v5 lite" in device_kind
   return expected_version in device_kind
 
+def is_device_gpu_at_least(capability: str) -> bool:
+  if device_under_test() != "gpu":
+    return False
+  d, *_ = jax.local_devices(backend="gpu")
+  return d.compute_capability >= capability
+
 def _get_device_tags():
   """returns a set of tags defined for the device under test"""
   if is_device_rocm():
diff --git a/jax/experimental/pallas/gpu.py b/jax/experimental/pallas/gpu.py
index 54e2cb6ba..8047aeed2 100644
--- a/jax/experimental/pallas/gpu.py
+++ b/jax/experimental/pallas/gpu.py
@@ -12,11 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Contains Triton specific Pallas functions."""
-from jax._src.pallas import triton
+"""Triton-specific Pallas APIs."""
+
 from jax._src.pallas.triton import approx_tanh
 from jax._src.pallas.triton import elementwise_inline_asm
-
-get_compute_capability = triton.get_compute_capability
-
-del triton
diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py
index edf3a7425..3b8b1e07c 100644
--- a/tests/pallas/export_back_compat_pallas_test.py
+++ b/tests/pallas/export_back_compat_pallas_test.py
@@ -26,7 +26,6 @@ from jax._src import test_util as jtu
 from jax._src.internal_test_util import export_back_compat_test_util as bctu
 from jax._src.internal_test_util.export_back_compat_test_data.pallas import cuda_add_one
 from jax.experimental import pallas as pl
-from jax.experimental.pallas import gpu as plgpu
 
 config.parse_flags_with_absl()
 
@@ -40,7 +39,7 @@ class CompatTest(bctu.CompatTestBase):
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Only works on GPU")
     if (jtu.test_device_matches(["cuda"]) and
-        plgpu.get_compute_capability(0) < 80):
+        not jtu.is_device_gpu_at_least("8.0")):
       self.skipTest("Only works on GPUs with capability >= sm80")
     super().setUp()
 
diff --git a/tests/pallas/gpu_attention_test.py b/tests/pallas/gpu_attention_test.py
index 3ff740227..2731ced04 100644
--- a/tests/pallas/gpu_attention_test.py
+++ b/tests/pallas/gpu_attention_test.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os
-import unittest
 
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -28,10 +27,6 @@ import numpy as np
 
 os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
 
-try:
-  from jax.experimental.pallas import gpu as plgpu
-except ImportError:
-  pass
 
 # pylint: disable=no-value-for-parameter
 
@@ -39,21 +34,12 @@ config.update("jax_traceback_filtering", "off")
 config.parse_flags_with_absl()
 
 
-class PallasTest(jtu.JaxTestCase):
-
-  def check_gpu_capability_at_least(self, capability, device: int = 0):
-    return plgpu.get_compute_capability(device) >= capability
+class DecodeAttentionTest(jtu.JaxTestCase):
 
   def setUp(self):
-    if not jtu.test_device_matches(["gpu"]):
-      self.skipTest("Only works on GPU")
-    try:
-      import triton  # noqa: F401
-    except ImportError:
-      self.skipTest("Triton is not installed. Skipping PallasTest.")
     super().setUp()
-
-class DecodeAttentionTest(PallasTest):
+    if not jtu.is_device_gpu_at_least("8.0"):
+      self.skipTest("Fused attention only works on GPUs with capability >= sm80")
 
   @parameterized.named_parameters(*[
       (
@@ -86,10 +72,6 @@ class DecodeAttentionTest(PallasTest):
       kwargs,
   ):
     del kwargs
-    if not self.check_gpu_capability_at_least(80):
-      raise unittest.SkipTest(
-          "Fused attention only works on GPUs with capability >= sm80"
-      )
 
     k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(k1, (batch_size, num_heads, head_dim), dtype=jnp.float16)
@@ -134,10 +116,6 @@ class DecodeAttentionTest(PallasTest):
      kwargs,
  ):
     del kwargs
-    if not self.check_gpu_capability_at_least(80):
-      raise unittest.SkipTest(
-          "Fused attention only works on GPUs with capability >= sm80"
-      )
 
     k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 72c49bbe8..0ce8a72bb 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -25,10 +25,7 @@ import jax.numpy as jnp
 from jax import lax
 from jax._src import test_util as jtu
 from jax.experimental import pallas as pl
-try:
-  from jax.experimental.pallas import gpu as plgpu
-except ImportError:
-  plgpu = None
+
 
 jax.config.parse_flags_with_absl()
 
@@ -40,19 +37,12 @@ class OpsTest(jtu.JaxTestCase):
     super().setUp()
     if jax.config.x64_enabled:
       self.skipTest("Only works in 32-bit")
-    if jtu.device_under_test() == "cpu" and not self.INTERPRET:
-      self.skipTest("Only interpreter mode supported on CPU")
-    if (jtu.test_device_matches(["cuda"]) and
-        not self.check_gpu_capability_at_least(80)):
-      self.skipTest("Only works on GPUs with capability >= sm80")
-
-  def check_gpu_capability_at_least(self, capability,
-                                    device: int = 0):
-    if plgpu is None:
-      return False
-    if self.INTERPRET:
-      return True
-    return plgpu.get_compute_capability(device) >= capability
+    if not self.INTERPRET:
+      if jtu.device_under_test() == "cpu":
+        self.skipTest("Only interpreter mode supported on CPU")
+      if (jtu.test_device_matches(["cuda"]) and
+          not jtu.is_device_gpu_at_least("8.0")):
+        self.skipTest("Only works on GPUs with capability >= sm80")
 
   @classmethod
   def pallas_call(cls, *args, **kwargs):
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index d49301a01..5b1c068ef 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -134,7 +134,7 @@ class PallasTest(parameterized.TestCase):
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Only works on GPU")
     if (jtu.test_device_matches(["cuda"]) and
-        not self.check_gpu_capability_at_least(80)):
+        not jtu.is_device_gpu_at_least("8.0")):
       self.skipTest("Only works on GPUs with capability >= sm80")
     super().setUp()
 
@@ -143,12 +143,6 @@ class PallasTest(parameterized.TestCase):
   def pallas_call(self, *args, **kwargs):
     return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
 
-  def check_gpu_capability_at_least(self, capability,
-                                    device: int = 0):
-    if self.INTERPRET:
-      return True
-    return plgpu.get_compute_capability(device) >= capability
-
 
 class PallasCallTest(PallasTest):
 
@@ -336,11 +330,6 @@ class PallasCallTest(PallasTest):
       if block_size_m <= m and block_size_n <= n and block_size_k <= k
   ])
   def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm):
-    if not self.INTERPRET and (
-        plgpu.get_compute_capability(0) <= 75
-        and (bm >= 128 or bn > 128 or bk > 32)
-    ):
-      raise unittest.SkipTest("Block sizes too big for sm70.")
     k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (m, k), dtype=dtype)
     y = random.normal(k2, (k, n), dtype=dtype)
@@ -362,12 +351,6 @@ class PallasCallTest(PallasTest):
      if block_size_m <= m and block_size_n <= n and block_size_k <= k
  ])
   def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk):
-    if not self.INTERPRET and (
-        plgpu.get_compute_capability(0) <= 75
-        and (bm >= 128 or bn > 128 or bk > 32)
-    ):
-      raise unittest.SkipTest("Block sizes too big for sm70.")
-
     k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (m, k), dtype=dtype)
     y = random.normal(k2, (k, n), dtype=dtype)
@@ -1666,7 +1649,7 @@ class PallasOpsTest(PallasTest):
   def test_approx_tanh(self, dtype):
     if self.INTERPRET:
       self.skipTest("approx_tanh is not supported in interpreter mode")
-    if dtype == "bfloat16" and not self.check_gpu_capability_at_least(90):
+    if dtype == "bfloat16" and not jtu.is_device_gpu_at_least("9.0"):
       self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90")
 
     @functools.partial(
diff --git a/tests/sparse_nm_test.py b/tests/sparse_nm_test.py
index 7bffc5ca0..11097ac13 100644
--- a/tests/sparse_nm_test.py
+++ b/tests/sparse_nm_test.py
@@ -22,7 +22,6 @@ import jax
 import jax.numpy as jnp
 from jax import dtypes
 from jax._src import test_util as jtu
-from jax.experimental.pallas import gpu as plgpu
 from jax.experimental.sparse import nm
 
 jax.config.parse_flags_with_absl()
@@ -33,14 +32,10 @@ class SpmmTest(jtu.JaxTestCase):
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Only works on GPU")
     if (jtu.test_device_matches(["cuda"]) and
-        not self.check_gpu_capability_at_least(80)):
+        not jtu.is_device_gpu_at_least("8.0")):
       self.skipTest("Only works on GPUs with capability >= sm80")
     super().setUp()
 
-  def check_gpu_capability_at_least(self, capability,
-                                    device: int = 0):
-    return plgpu.get_compute_capability(device) >= capability
-
   # ----- Test different input shapes
   @parameterized.product(
       tile_m=(32, 128),
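
For downstream users of the removed function, a minimal migration sketch, assuming jaxlib >= 0.4.26 built with CUDA support and at least one local GPU; the names `gpu_device` and `cc` below are illustrative, not part of the JAX API.

    # Before (removed by this patch): an int capability, keyed by device ordinal.
    #   from jax.experimental.pallas import gpu as plgpu
    #   cc = plgpu.get_compute_capability(0)  # e.g. 80

    # After: read the capability directly from a GPU device object.
    import jax

    gpu_device, *_ = jax.local_devices(backend="gpu")
    cc = gpu_device.compute_capability  # str, e.g. "8.0"

    # The jtu.is_device_gpu_at_least helper added to jax/_src/test_util.py in this
    # patch compares these strings directly, e.g. cc >= "8.0".
    if cc >= "8.0":
      print(f"compute capability {cc}: sm80-only Pallas kernels are supported")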