Add JAX API that provides sparse matmul support (2:4 structured sparsity)

Usage:
from jax.experimental.sparse import nm
res = nm.nm_spmm(lhs, rhs, nm.nm_pack(mask))

where:
lhs.shape = [M, K/2]
rhs.shape = [K, N]
`mask` has the same shape as `lhs` with boolean type

If batch dimensions are present, the `dimension_numbers` argument has to be set to:
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

The lowering only works on nVidia GPUs, that provide hardware support for sparse dots.

PiperOrigin-RevId: 627640553
This commit is contained in:
Sergey Kozub 2024-04-24 01:05:45 -07:00 committed by jax authors
parent b5fdc0d90f
commit aebe82a78f
3 changed files with 458 additions and 0 deletions

View File

@ -0,0 +1,241 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""N:M-sparsity associated primitives."""
from jax import core
from jax._src import dispatch
from jax._src.lax.lax import DotDimensionNumbers
from jax._src.lib import gpu_sparse
from jax._src.lib.mlir.dialects import mhlo
from jax._src.typing import Array, DTypeLike
from jax.interpreters import mlir
import jax.numpy as jnp
import numpy as np
# --------------------------------------------------------------------
# nm_spmm
nm_spmm_p = core.Primitive("sparse_dense_matmul")
_supported_input_types = (jnp.int8, jnp.int16, jnp.float16, jnp.bfloat16)
_supported_output_types = (jnp.bfloat16, jnp.float32)
def nm_spmm(
lhs: Array,
rhs: Array,
metadata: Array,
dimension_numbers: DotDimensionNumbers = (((1,), (0,)), (tuple(), tuple())),
sparse_operand_idx: int = 0,
output_dtype: DTypeLike = jnp.bfloat16,
) -> Array:
"""Dot operation where one of the operands has N:M sparsity.
Args:
lhs: An ndarray (first dot operand).
rhs: An ndarray (second dot operand).
metadata: An ndarray with structured sparsity metadata for the contracting
dimension. For 2:4 sparsity it should contain (N=2) two-bit index values
for each (M=4) element group.
dimension_numbers: a tuple of tuples of the form `((lhs_contracting_dims,
rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
sparse_operand_idx: index of the sparse operand (0 or 1).
output_dtype: result type.
Returns:
An ndarray dense array containing the result.
"""
return nm_spmm_p.bind(
lhs,
rhs,
metadata,
dimension_numbers=dimension_numbers,
sparse_operand_idx=sparse_operand_idx,
output_dtype=output_dtype,
)
def _calc_groups_per_element(n, m):
group_bits = n * (m.bit_length() - 1) # 4 bits per group for 2:4
return 16 // group_bits
def _validate_dnums(rank, contract, batch, name):
non_contract = tuple(sorted(set(range(rank)) - set(contract + batch)))
if sorted(non_contract + contract + batch) != list(range(rank)):
raise TypeError(f"Incorrect dimension numbers for {name}")
return non_contract
def _validate_metadata(lhs, rhs, metadata, dimension_numbers, index, n=2, m=4):
assert index in (0, 1)
size_factor = n * _calc_groups_per_element(n, m)
sparse = [lhs, rhs][index]
sparse_contract = dimension_numbers[0][index]
if metadata.dtype != np.uint16:
raise TypeError(f"Metadata must be uint16, got {metadata.dtype}")
if sparse_contract[0] != sparse.ndim - 1:
raise TypeError("Contracting dimension must be the minor one")
if metadata.shape[:-1] != sparse.shape[:-1]:
raise TypeError(
"Metadata shape must match the operand shape (except for the"
" contracting dimension)"
)
if metadata.shape[-1] * size_factor != sparse.shape[-1]:
raise TypeError(
f"Metadata must be exactly {size_factor} times less than the"
f" contracting dimension for {n}:{m} structured sparsity (expected"
f" {sparse.shape[-1] // size_factor}, got {metadata.shape[-1]})"
)
if sparse.shape[-1] % size_factor != 0:
raise NotImplementedError("Metadata with padding is not supported")
dense = [lhs, rhs][1 - index]
dense_contract = dimension_numbers[0][1 - index]
a, b = sparse.shape[sparse_contract[0]], dense.shape[dense_contract[0]]
if n * b != m * a:
raise TypeError(
f"Contracting dimension sizes should have {n}:{m} ratio, got {a}:{b}"
)
def _infer_result_shape(lhs, rhs, dimension_numbers):
((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers
if len(lhs_contract) != 1 or len(rhs_contract) != 1:
raise TypeError("Only single contracting dimension is supported")
lhs_dims = _validate_dnums(lhs.ndim, lhs_contract, lhs_batch, "lhs")
rhs_dims = _validate_dnums(rhs.ndim, rhs_contract, rhs_batch, "rhs")
if len(lhs_dims) != 1 or len(rhs_dims) != 1:
raise TypeError("Only single non-contracting dimension is supported")
batch = [lhs.shape[i] for i in lhs_batch]
if batch != [rhs.shape[i] for i in rhs_batch]:
raise TypeError("Batch dimension sizes do not match")
return tuple(batch + [lhs.shape[lhs_dims[0]], rhs.shape[rhs_dims[0]]])
def _nm_spmm_default_lowering(*_args, **_kwargs):
raise NotImplementedError("Sparse N:M matmul is only implemented on GPU")
def _nm_spmm_gpu_lowering(
ctx,
lhs,
rhs,
metadata,
*,
dimension_numbers,
sparse_operand_idx,
output_dtype,
):
assert sparse_operand_idx in (0, 1)
sparsity_descriptor = mhlo.SparsityDescriptor.get(
dimension=dimension_numbers[0][sparse_operand_idx][0], n=2, m=4
)
dot_dnums = mhlo.DotDimensionNumbers.get(
lhs_batching_dimensions=dimension_numbers[1][sparse_operand_idx],
rhs_batching_dimensions=dimension_numbers[1][1 - sparse_operand_idx],
lhs_contracting_dimensions=dimension_numbers[0][sparse_operand_idx],
rhs_contracting_dimensions=dimension_numbers[0][1 - sparse_operand_idx],
)
dot_type = ctx.avals_out[0]
key = ["lhs_sparsity", "rhs_sparsity"][sparse_operand_idx]
kwargs = {key: sparsity_descriptor}
op = mhlo.SparseDotOp(
mlir.aval_to_ir_type(dot_type), lhs, rhs, [metadata], dot_dnums, **kwargs
)
return op.results
@nm_spmm_p.def_abstract_eval
def _nm_spmm_abstract_eval(
lhs, rhs, metadata, *, dimension_numbers, sparse_operand_idx, output_dtype
):
if lhs.dtype not in _supported_input_types:
raise TypeError(f"Unsupported lhs input type: {lhs.dtype}")
if rhs.dtype not in _supported_input_types:
raise TypeError(f"Unsupported rhs input type: {rhs.dtype}")
if output_dtype not in _supported_output_types:
raise TypeError(f"Unsupported output type: {output_dtype}")
res_shape = _infer_result_shape(lhs, rhs, dimension_numbers)
_validate_metadata(lhs, rhs, metadata, dimension_numbers, sparse_operand_idx)
return core.ShapedArray(res_shape, output_dtype)
mlir.register_lowering(nm_spmm_p, _nm_spmm_default_lowering)
dispatch.simple_impl(nm_spmm_p)
if gpu_sparse.cuda_is_supported:
mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="cuda")
# --------------------------------------------------------------------
# nm_pack
nm_pack_p = core.Primitive("sparse_pack_nm")
def nm_pack(mask: Array, n=2, m=4) -> Array:
"""Generate metadata tensor for an N:M mask.
Args:
mask: Predicates for the input tensor, where the elements are grouped in the
minor dimension. In each group of size M there should be exactly N true
values, which mark the data elements to keep.
n: Number of non-zero elements in a group.
m: Group size.
Returns:
An ndarray containing only the masked input elements.
"""
return nm_pack_p.bind(mask, n=n, m=m)
def _compress(data, n, m, k):
result = []
expected = n * (k // m)
for i in range(0, len(data), k):
index = tuple(jnp.nonzero(data[i : i + k], size=expected)[0] % m)
value = sum(j * pow(m, i) for i, j in enumerate(index))
result.append(value)
return jnp.array(result, dtype=np.uint16)
@nm_pack_p.def_impl
def _nm_pack_impl(mask, *, n, m):
batch_size = m * _calc_groups_per_element(n, m)
return jnp.apply_along_axis(
lambda x: _compress(x, n, m, batch_size), -1, mask
)
@nm_pack_p.def_abstract_eval
def _nm_pack_abstract_eval(mask, *, n, m):
size_factor = m * _calc_groups_per_element(n, m)
if mask.dtype != bool:
raise TypeError(f"Mask should be bool, got {mask.dtype}")
if mask.shape[-1] % size_factor != 0:
raise TypeError(
f"Inner dimension size should be divisible by {size_factor}, got"
f" {mask.shape}"
)
res_shape = list(mask.shape)
res_shape[-1] //= size_factor
return core.ShapedArray(res_shape, np.uint16)
_nm_pack_lowering = mlir.lower_fun(_nm_pack_impl, multiple_results=False)
mlir.register_lowering(nm_pack_p, _nm_pack_lowering)
dispatch.simple_impl(nm_pack_p)

View File

@ -984,6 +984,23 @@ jax_test(
] + py_deps("scipy"),
)
jax_test(
name = "sparse_nm_test",
srcs = ["sparse_nm_test.py"],
disable_backends = [
"cpu",
"gpu",
"tpu",
],
enable_configs = [
"gpu_a100",
"gpu_h100",
],
deps = [
"//jax:experimental_sparse",
],
)
jax_test(
name = "sparsify_test",
srcs = ["sparsify_test.py"],

200
tests/sparse_nm_test.py Normal file
View File

@ -0,0 +1,200 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import dtypes
from jax._src import test_util as jtu
from jax.experimental.sparse import nm
import jax.numpy as jnp
import numpy as np
jax.config.parse_flags_with_absl()
class SpmmTest(jtu.JaxTestCase):
# ----- Test different input shapes
@parameterized.product(
tile_m=(32, 128),
tile_n=(32, 128),
tile_k=(32, 128),
batch=(None, 5),
sparse_idx=(0, 1),
)
@jtu.run_on_devices("gpu")
def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx):
# Build keyword arguments
kwargs = {
"dimension_numbers": (((1,), (1,)), (tuple(), tuple())),
"sparse_operand_idx": sparse_idx,
}
if batch:
kwargs["dimension_numbers"] = (((2,), (2,)), ((0,), (0,)))
# Build input data
batch_dims = (batch,) if batch else tuple()
lhs = (
(np.arange((batch or 1) * tile_m * tile_k) % 11)
.astype(dtypes.bfloat16)
.reshape(batch_dims + (tile_m, tile_k))
)
rhs = (
(np.arange((batch or 1) * tile_n * tile_k) % 13)
.astype(dtypes.bfloat16)
.reshape(batch_dims + (tile_n, tile_k))
)
# Build sparsity mask and metadata
sp = [lhs, rhs][sparse_idx]
mask = np.tile([True, False], math.prod(sp.shape) // 2).reshape(sp.shape)
sparse = sp[mask].reshape(sp.shape[:-1] + (sp.shape[-1] // 2,))
meta = nm.nm_pack(mask)
# Calculate sparse and dense dots
if sparse_idx == 0:
dot_sparse = nm.nm_spmm(sparse, rhs, meta, **kwargs)
dot_dense = jnp.einsum("...mk,...nk->...mn", (lhs * mask), rhs)
else:
dot_sparse = nm.nm_spmm(lhs, sparse, meta, **kwargs)
dot_dense = jnp.einsum("...mk,...nk->...mn", lhs, (rhs * mask))
# Verify the result
jtu.check_eq(dot_sparse, dot_dense.astype(dtypes.bfloat16))
# ----- Test different input types
# TODO(b/336519663): add int8 type once codegen is fixed
@parameterized.product(
lhs_type=[jnp.int16, jnp.float16, jnp.bfloat16],
rhs_type=[jnp.bfloat16],
output_type=[jnp.bfloat16, jnp.float32],
)
@jtu.run_on_devices("gpu")
def test_types(self, lhs_type, rhs_type, output_type):
tile_m, tile_n, tile_k = 64, 32, 128
# Build input data
lhs = (
(np.arange(tile_m * tile_k) % 17)
.astype(lhs_type)
.reshape((tile_m, tile_k))
)
rhs = (
(np.arange(tile_k * tile_n) % 19)
.astype(rhs_type)
.reshape((tile_k, tile_n))
)
# Build sparsity mask and metadata
mask = np.tile([True, False], tile_m * tile_k // 2).reshape(lhs.shape)
sparse = lhs[mask].reshape(tile_m, tile_k // 2)
meta = nm.nm_pack(mask)
# Calculate sparse and dense dots
dot_sparse = nm.nm_spmm(sparse, rhs, meta, output_dtype=output_type)
dot_dense = (lhs * mask) @ rhs
# Verify the result
jtu.check_close(dot_sparse, dot_dense.astype(output_type), rtol=0.01)
# ----- Test validation
@jtu.run_on_devices("gpu")
def test_validate_nm_pack(self):
with self.assertRaisesRegex(TypeError, "Mask should be bool"):
nm.nm_pack(jnp.zeros(16, jnp.int8))
with self.assertRaisesRegex(
TypeError, "Inner dimension size should be divisible by 16"
):
nm.nm_pack(jnp.array([False] * 8))
@jtu.run_on_devices("gpu")
def test_validate_nm_spmm(self):
batch, tile_m, tile_n, tile_k = 2, 64, 32, 128
lhs = jnp.zeros((batch, tile_m, tile_k // 2), dtype=jnp.bfloat16)
rhs = jnp.zeros((batch, tile_k, tile_n), dtype=jnp.bfloat16)
meta = jnp.zeros((batch, tile_m, tile_k // 16), dtype=jnp.uint16)
# Check types
with self.assertRaisesRegex(TypeError, "Unsupported lhs input type"):
nm.nm_spmm(jnp.zeros(lhs.shape, dtype=jnp.int64), rhs, meta)
with self.assertRaisesRegex(TypeError, "Unsupported rhs input type"):
nm.nm_spmm(lhs, jnp.zeros(rhs.shape, dtype=jnp.int64), meta)
with self.assertRaisesRegex(TypeError, "Unsupported output type"):
nm.nm_spmm(lhs, rhs, meta, output_dtype=jnp.int64)
# Check dimension numbers
nm_spmm_with_dnums = lambda c, b: nm.nm_spmm(
lhs, rhs, meta, dimension_numbers=(c, b)
)
with self.assertRaisesRegex(
TypeError, "Only single contracting dimension is supported"
):
nm_spmm_with_dnums(((0, 2), (0, 1)), (tuple(), tuple()))
with self.assertRaisesRegex(
TypeError, "Incorrect dimension numbers for lhs"
):
nm_spmm_with_dnums(((2,), (1,)), ((2,), (0,)))
with self.assertRaisesRegex(
TypeError, "Incorrect dimension numbers for rhs"
):
nm_spmm_with_dnums(((2,), (1,)), ((0,), (1,)))
with self.assertRaisesRegex(
TypeError, "Only single non-contracting dimension is supported"
):
nm_spmm_with_dnums(((2,), (1,)), (tuple(), tuple()))
with self.assertRaisesRegex(
TypeError, "Batch dimension sizes do not match"
):
nm.nm_spmm(
lhs,
rhs.reshape(1, tile_k, tile_n * batch),
meta,
dimension_numbers=(((2,), (1,)), ((0,), (0,))),
)
# Check metadata
nm_spmm_with_meta = lambda m: nm.nm_spmm(
lhs, rhs, m, dimension_numbers=(((2,), (1,)), ((0,), (0,)))
)
with self.assertRaisesRegex(TypeError, "Metadata must be uint16"):
nm_spmm_with_meta(jnp.zeros(meta.shape, dtype=jnp.uint8))
with self.assertRaisesRegex(
TypeError, "Metadata shape must match the operand shape"
):
nm_spmm_with_meta(meta.reshape(1, batch * tile_m, tile_k // 16))
with self.assertRaisesRegex(
TypeError,
"Metadata must be exactly 8 times less than the contracting dimension"
" for 2:4 structured sparsity",
):
nm_spmm_with_meta(jnp.repeat(meta, 2, axis=-1))
with self.assertRaisesRegex(
TypeError, "Contracting dimension must be the minor one"
):
nm.nm_spmm(lhs, rhs, meta, dimension_numbers=(((1,), (1,)), ((0,), (0,))))
with self.assertRaisesRegex(
TypeError, "Contracting dimension sizes should have 2:4 ratio"
):
nm.nm_spmm(
lhs,
jnp.repeat(rhs, 2, axis=1),
meta,
dimension_numbers=(((2,), (1,)), ((0,), (0,))),
)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())