Mirror of https://github.com/ROCm/jax.git
Removed get_compute_capability from jax.experimental.pallas.gpu
Since jaxlib 0.4.26, the compute capability has been available as the `str`-valued `compute_capability` attribute of a GPU device.
Commit: 575ba942e0
Parent: a145109ac2
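For code that still needs the compute capability, the replacement is to read it from a GPU device object. A minimal migration sketch, assuming a CUDA-backed jaxlib 0.4.26 or newer is installed; note that the removed helper returned an integer such as 80, while the device attribute is a string such as "8.0":

import jax

# Before this commit (now removed):
#   from jax.experimental.pallas import gpu as plgpu
#   cc = plgpu.get_compute_capability(0)   # e.g. 80, as an int

# After this commit: read the attribute from a local GPU device.
d, *_ = jax.local_devices(backend="gpu")
cc = d.compute_capability                  # e.g. "8.0", as a str
print(f"{d.device_kind}: compute capability {cc}")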
@@ -11,6 +11,9 @@ Remember to align the itemized text with the first line of an item within a list
 * Deprecations & removals
   * The ``kind`` argument to {func}`jax.numpy.sort` and {func}`jax.numpy.argsort`
     is now removed. Use `stable=True` or `stable=False` instead.
+  * Removed ``get_compute_capability`` from the ``jax.experimental.pallas.gpu``
+    module. Use the ``compute_capability`` attribute of a GPU device, returned
+    by {func}`jax.devices` or {func}`jax.local_devices`, instead.
 
 ## jaxlib 0.4.28
 
@@ -12,24 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Contains Triton-specific pallas modules."""
+"""Triton-specific Pallas APIs."""
 
-from jax._src.lib import gpu_triton as triton_kernel_call_lib
 from jax._src.pallas.triton.primitives import approx_tanh
 from jax._src.pallas.triton.primitives import elementwise_inline_asm
-
-
-try:
-  get_compute_capability = triton_kernel_call_lib.get_compute_capability
-except AttributeError:
-
-  def get_compute_capability(device) -> int:
-    del device  # Unused.
-    raise RuntimeError(
-        "get_compute_capability is not available. Try installing jaxlib with"
-        " GPU support following instructions in"
-        " https://jax.readthedocs.io/en/latest/installation.html."
-    )
-
-
-del triton_kernel_call_lib
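The deleted block used a common pattern for optional native bindings: bind the symbol when the installed jaxlib provides it, otherwise install a stub that raises with installation guidance. A generic, runnable sketch of the same pattern with hypothetical names (`native_lib`, `fast_op` are stand-ins, not JAX or jaxlib APIs):

import types

# `native_lib` stands in for an optional native extension module; here it is a
# plain namespace, so the AttributeError branch below is exercised.
native_lib = types.SimpleNamespace()

try:
  fast_op = native_lib.fast_op  # present only in new enough builds
except AttributeError:
  def fast_op(*args, **kwargs):
    # Stub that fails loudly, mirroring the removed fallback above.
    raise RuntimeError(
        "fast_op is unavailable; upgrade the native extension to use it.")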
@@ -25,7 +25,6 @@ from typing import Any
 import jax
 from jax import core as jax_core
 from jax._src.interpreters import mlir
-from jax._src.lib import gpu_triton as triton_kernel_call_lib
 from jax._src.lib.mlir import ir
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.pallas_call import pallas_call_p
@@ -58,13 +57,10 @@ def _pallas_call_ttir_lowering(
     num_warps: int,
     num_stages: int,
 ):
-  # TODO(sharadmv): handle multiple devices, right now we assume device 0
-  # which is fine when we have multiple of the same GPU but this won't work in
-  # general.
-  device = 0
-  compute_capability = triton_kernel_call_lib.get_compute_capability(device)
+  # TODO(sharadmv): Handle multiple devices with different capabilities.
+  d, *_ = jax.local_devices(backend="gpu")
   cuda_options = dict(
-      compute_capability=compute_capability,
+      compute_capability=d.compute_capability,
       num_warps=num_warps,
       num_stages=num_stages,
       debug=debug,
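The new lowering takes the capability from the first local GPU device rather than calling into the native library with a hard-coded device index. Code that may also run on hosts without a GPU backend has to guard that lookup, since `jax.local_devices(backend="gpu")` raises when no such backend is available. A small hypothetical helper sketching that guard (the function name and fallback value are illustrative, not part of JAX):

import jax

def first_gpu_compute_capability(fallback=None):
  # Return the compute capability string of the first visible GPU, or
  # `fallback` when JAX has no usable GPU backend on this host.
  try:
    d, *_ = jax.local_devices(backend="gpu")
  except RuntimeError:
    return fallback
  return d.compute_capability

print(first_gpu_compute_capability(fallback="no GPU backend"))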
@@ -419,6 +419,12 @@ def is_device_tpu(version: int | None = None, variant: str = "") -> bool:
     return "v5 lite" in device_kind
   return expected_version in device_kind
 
+
+def is_device_gpu_at_least(capability: str) -> bool:
+  if device_under_test() != "gpu":
+    return False
+  d, *_ = jax.local_devices(backend="gpu")
+  return d.compute_capability >= capability
+
 
 def _get_device_tags():
   """returns a set of tags defined for the device under test"""
   if is_device_rocm():
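The new helper gives tests a backend-agnostic skip condition. A minimal sketch of how a test can use it, mirroring the skip pattern used throughout this commit; `jax._src.test_util` is an internal module, so this follows JAX's own test style rather than a public API, and the test body is illustrative:

from absl.testing import absltest
import jax
from jax._src import test_util as jtu

jax.config.parse_flags_with_absl()


class MyGpuTest(jtu.JaxTestCase):

  def setUp(self):
    if not jtu.test_device_matches(["gpu"]):
      self.skipTest("Only works on GPU")
    if (jtu.test_device_matches(["cuda"]) and
        not jtu.is_device_gpu_at_least("8.0")):
      self.skipTest("Only works on GPUs with capability >= sm80")
    super().setUp()

  def test_devices_are_gpus(self):
    # Trivial check; real tests would exercise sm80-only features here.
    self.assertIn(jax.devices()[0].platform, ("gpu", "cuda", "rocm"))


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())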
@@ -12,11 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Contains Triton specific Pallas functions."""
-from jax._src.pallas import triton
+"""Triton-specific Pallas APIs."""
+
 from jax._src.pallas.triton import approx_tanh
 from jax._src.pallas.triton import elementwise_inline_asm
-
-get_compute_capability = triton.get_compute_capability
-
-del triton
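After this change, `jax.experimental.pallas.gpu` only re-exports the Triton-specific primitives. A short sketch of using one of them, `approx_tanh`, inside a Pallas kernel; this assumes a CUDA GPU with the Triton lowering path available, and the kernel is an illustration rather than code from the JAX test suite:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import gpu as plgpu

def tanh_kernel(x_ref, o_ref):
  # approx_tanh lowers to the hardware approximate-tanh instruction on CUDA.
  o_ref[...] = plgpu.approx_tanh(x_ref[...])

@jax.jit
def approx_tanh(x):
  return pl.pallas_call(
      tanh_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x)

x = jnp.linspace(-3.0, 3.0, 128, dtype=jnp.float32)
print(approx_tanh(x))  # close to jnp.tanh(x), up to the approximation error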
@@ -26,7 +26,6 @@ from jax._src import test_util as jtu
 from jax._src.internal_test_util import export_back_compat_test_util as bctu
 from jax._src.internal_test_util.export_back_compat_test_data.pallas import cuda_add_one
 from jax.experimental import pallas as pl
-from jax.experimental.pallas import gpu as plgpu
 
 config.parse_flags_with_absl()
@@ -40,7 +39,7 @@ class CompatTest(bctu.CompatTestBase):
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Only works on GPU")
     if (jtu.test_device_matches(["cuda"]) and
-        plgpu.get_compute_capability(0) < 80):
+        not jtu.is_device_gpu_at_least("8.0")):
       self.skipTest("Only works on GPUs with capability >= sm80")
     super().setUp()
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os
-import unittest
 
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -28,10 +27,6 @@ import numpy as np
 
 os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
 
-try:
-  from jax.experimental.pallas import gpu as plgpu
-except ImportError:
-  pass
 # pylint: disable=no-value-for-parameter
 
 
@@ -39,21 +34,12 @@ config.update("jax_traceback_filtering", "off")
 config.parse_flags_with_absl()
 
 
-class PallasTest(jtu.JaxTestCase):
-
-  def check_gpu_capability_at_least(self, capability, device: int = 0):
-    return plgpu.get_compute_capability(device) >= capability
+class DecodeAttentionTest(jtu.JaxTestCase):
 
   def setUp(self):
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Only works on GPU")
-    try:
-      import triton  # noqa: F401
-    except ImportError:
-      self.skipTest("Triton is not installed. Skipping PallasTest.")
+    if not jtu.is_device_gpu_at_least("8.0"):
+      self.skipTest("Fused attention only works on GPUs with capability >= sm80")
     super().setUp()
-
-
-class DecodeAttentionTest(PallasTest):
 
   @parameterized.named_parameters(*[
       (
@@ -86,10 +72,6 @@ class DecodeAttentionTest(PallasTest):
       kwargs,
   ):
     del kwargs
-    if not self.check_gpu_capability_at_least(80):
-      raise unittest.SkipTest(
-          "Fused attention only works on GPUs with capability >= sm80"
-      )
 
     k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(k1, (batch_size, num_heads, head_dim), dtype=jnp.float16)
@@ -134,10 +116,6 @@ class DecodeAttentionTest(PallasTest):
       kwargs,
   ):
     del kwargs
-    if not self.check_gpu_capability_at_least(80):
-      raise unittest.SkipTest(
-          "Fused attention only works on GPUs with capability >= sm80"
-      )
 
     k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(
@@ -25,10 +25,7 @@ import jax.numpy as jnp
 from jax import lax
 from jax._src import test_util as jtu
 from jax.experimental import pallas as pl
-try:
-  from jax.experimental.pallas import gpu as plgpu
-except ImportError:
-  plgpu = None
 
 
 jax.config.parse_flags_with_absl()
@@ -40,19 +37,12 @@ class OpsTest(jtu.JaxTestCase):
     super().setUp()
     if jax.config.x64_enabled:
       self.skipTest("Only works in 32-bit")
-    if jtu.device_under_test() == "cpu" and not self.INTERPRET:
-      self.skipTest("Only interpreter mode supported on CPU")
-    if (jtu.test_device_matches(["cuda"]) and
-        not self.check_gpu_capability_at_least(80)):
-      self.skipTest("Only works on GPUs with capability >= sm80")
-
-  def check_gpu_capability_at_least(self, capability,
-                                    device: int = 0):
-    if plgpu is None:
-      return False
-    if self.INTERPRET:
-      return True
-    return plgpu.get_compute_capability(device) >= capability
+    if not self.INTERPRET:
+      if jtu.device_under_test() == "cpu":
+        self.skipTest("Only interpreter mode supported on CPU")
+      if (jtu.test_device_matches(["cuda"]) and
+          not jtu.is_device_gpu_at_least("8.0")):
+        self.skipTest("Only works on GPUs with capability >= sm80")
 
   @classmethod
   def pallas_call(cls, *args, **kwargs):
@@ -134,7 +134,7 @@ class PallasTest(parameterized.TestCase):
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Only works on GPU")
     if (jtu.test_device_matches(["cuda"]) and
-        not self.check_gpu_capability_at_least(80)):
+        not jtu.is_device_gpu_at_least("8.0")):
      self.skipTest("Only works on GPUs with capability >= sm80")
 
     super().setUp()
@@ -143,12 +143,6 @@ class PallasTest(parameterized.TestCase):
   def pallas_call(self, *args, **kwargs):
     return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
 
-  def check_gpu_capability_at_least(self, capability,
-                                    device: int = 0):
-    if self.INTERPRET:
-      return True
-    return plgpu.get_compute_capability(device) >= capability
-
 
 class PallasCallTest(PallasTest):
 
@@ -336,11 +330,6 @@ class PallasCallTest(PallasTest):
       if block_size_m <= m and block_size_n <= n and block_size_k <= k
   ])
   def test_matmul(self, m, n, k, dtype, bm, bn, bk, gm):
-    if not self.INTERPRET and (
-        plgpu.get_compute_capability(0) <= 75
-        and (bm >= 128 or bn > 128 or bk > 32)
-    ):
-      raise unittest.SkipTest("Block sizes too big for sm70.")
     k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (m, k), dtype=dtype)
     y = random.normal(k2, (k, n), dtype=dtype)
@@ -362,12 +351,6 @@ class PallasCallTest(PallasTest):
       if block_size_m <= m and block_size_n <= n and block_size_k <= k
   ])
   def test_matmul_block_spec(self, m, n, k, dtype, bm, bn, bk):
-    if not self.INTERPRET and (
-        plgpu.get_compute_capability(0) <= 75
-        and (bm >= 128 or bn > 128 or bk > 32)
-    ):
-      raise unittest.SkipTest("Block sizes too big for sm70.")
-
     k1, k2 = random.split(random.key(0))
     x = random.normal(k1, (m, k), dtype=dtype)
     y = random.normal(k2, (k, n), dtype=dtype)
@@ -1666,7 +1649,7 @@ class PallasOpsTest(PallasTest):
   def test_approx_tanh(self, dtype):
     if self.INTERPRET:
       self.skipTest("approx_tanh is not supported in interpreter mode")
-    if dtype == "bfloat16" and not self.check_gpu_capability_at_least(90):
+    if dtype == "bfloat16" and not jtu.is_device_gpu_at_least("9.0"):
       self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90")
 
     @functools.partial(
@@ -22,7 +22,6 @@ import jax
 import jax.numpy as jnp
 from jax import dtypes
 from jax._src import test_util as jtu
-from jax.experimental.pallas import gpu as plgpu
 from jax.experimental.sparse import nm
 
 jax.config.parse_flags_with_absl()
@@ -33,14 +32,10 @@ class SpmmTest(jtu.JaxTestCase):
     if not jtu.test_device_matches(["gpu"]):
       self.skipTest("Only works on GPU")
     if (jtu.test_device_matches(["cuda"]) and
-        not self.check_gpu_capability_at_least(80)):
+        not jtu.is_device_gpu_at_least("8.0")):
       self.skipTest("Only works on GPUs with capability >= sm80")
     super().setUp()
 
-  def check_gpu_capability_at_least(self, capability,
-                                    device: int = 0):
-    return plgpu.get_compute_capability(device) >= capability
-
   # ----- Test different input shapes
   @parameterized.product(
       tile_m=(32, 128),