mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00

In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used for dot product accumulation. When using this feature in real use cases, we found that the API is both too conservative (it required the user to pass inputs with the appropriate types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.

Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match what a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: we can just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement it. One minor downside of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short-term approach.

PiperOrigin-RevId: 683302687
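As a rough illustration of the updated interface, here is a minimal sketch. The `jax.lax.DotAlgorithmPreset` enum and the `BF16_BF16_F32` preset name used below are assumptions for the example, not taken from this description; the dtypes and shapes are arbitrary.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((8, 16), dtype=jnp.float32)
y = jnp.ones((16, 4), dtype=jnp.float32)

# Request a bf16 x bf16 -> f32 accumulation algorithm via `precision`.
# Per the description above, the lowering casts the f32 operands to the
# algorithm's bf16 storage types, and the f32 accumulator is cast back so
# the output dtype matches the inputs.
out = lax.dot_general(
    x, y,
    dimension_numbers=(((1,), (0,)), ((), ())),
    precision=lax.DotAlgorithmPreset.BF16_BF16_F32,  # assumed preset name
)
assert out.dtype == jnp.float32

# `preferred_element_type` still selects a specific output dtype, even when
# an algorithm is requested.
out_bf16 = lax.dot_general(
    x, y,
    dimension_numbers=(((1,), (0,)), ((), ())),
    precision=lax.DotAlgorithmPreset.BF16_BF16_F32,  # assumed preset name
    preferred_element_type=jnp.bfloat16,
)
assert out_bf16.dtype == jnp.bfloat16
```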
117 lines
4.3 KiB
Python
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sparse utilities."""

import functools
from typing import Any, NamedTuple, Union

import numpy as np

import jax
from jax import lax
from jax import tree_util
from jax import vmap
from jax._src import core
from jax._src import dtypes
from jax._src import stages
from jax._src.api_util import flatten_axes
import jax.numpy as jnp
from jax.util import safe_zip
from jax._src.lax.lax import _dot_general_shape_rule, DotDimensionNumbers
from jax._src.typing import Array

class SparseEfficiencyError(ValueError):
  pass

class SparseEfficiencyWarning(UserWarning):
  pass

class CuSparseEfficiencyWarning(SparseEfficiencyWarning):
  pass

Shape = tuple[int, ...]

class SparseInfo(NamedTuple):
  shape: Shape
  indices_sorted: bool = False
  unique_indices: bool = False

#--------------------------------------------------------------------
# utilities
# TODO: possibly make these primitives, targeting cusparse routines
#       csr2coo/coo2csr/SPDDMM

def nfold_vmap(fun, N, *, broadcasted=True, in_axes=0):
  """Convenience function to apply (broadcasted) vmap N times."""
  _vmap = broadcasting_vmap if broadcasted else vmap
  for _ in range(N):
    fun = _vmap(fun, in_axes=in_axes)
  return fun

def broadcasting_vmap(fun, in_axes=0, out_axes=0):
  """Like vmap, but broadcasts arguments whose mapped axis has length 1."""
  @functools.wraps(fun)
  def batched_fun(*args):
    args_flat, in_tree = tree_util.tree_flatten(args)
    in_axes_flat = flatten_axes("vmap in_axes", in_tree, in_axes, kws=False)
    size = max(arg.shape[i] for arg, i in safe_zip(args_flat, in_axes_flat) if i is not None)
    if size > 1:
      if any(i is not None and arg.shape[i] not in (1, size)
             for arg, i in safe_zip(args_flat, in_axes_flat)):
        raise ValueError("broadcasting_vmap: mismatched input shapes")
      args_flat, in_axes_flat = zip(*(
          (arg, None) if i is None else (lax.squeeze(arg, (i,)), None) if arg.shape[i] == 1 else (arg, i)
          for arg, i in zip(args_flat, in_axes_flat)
      ))
    new_args = tree_util.tree_unflatten(in_tree, args_flat)
    new_in_axes = tree_util.tree_unflatten(in_tree, in_axes_flat)
    return vmap(fun, in_axes=new_in_axes, out_axes=out_axes)(*new_args)
  return batched_fun

@jax.jit
def _csr_to_coo(indices: Array, indptr: Array) -> tuple[Array, Array]:
  """Given CSR (indices, indptr), return COO (row, col)."""
  # Scatter a 1 at each row boundary given by indptr, then take a cumulative
  # sum (offset by -1) to recover the row index of every stored element.
  return jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1, indices

def _csr_extract(indices: Array, indptr: Array, mat: Array) -> Array:
  """Extract values of dense matrix mat at given CSR indices."""
  row, col = _csr_to_coo(indices, indptr)
  return _coo_extract(row, col, mat)

def _coo_extract(row: Array, col: Array, mat: Array) -> Array:
  """Extract values of dense matrix mat at given COO indices."""
  return mat[row, col]

def _count_stored_elements_per_batch(mat: Array, n_batch: int = 0, n_dense: int = 0) -> Array:
  """Return per-batch number of stored elements (nse) of a dense matrix."""
  mat = jnp.asarray(mat)
  mask = (mat != 0)
  if n_dense > 0:
    mask = mask.any(tuple(-(i + 1) for i in range(n_dense)))
  mask = mask.sum(tuple(range(n_batch, mask.ndim)))
  return mask

def _count_stored_elements(mat: Array, n_batch: int = 0, n_dense: int = 0) -> Array:
  """Return the number of stored elements (nse) of the given dense matrix."""
  return _count_stored_elements_per_batch(mat, n_batch, n_dense).max(initial=0)

def _dot_general_validated_shape(
    lhs_shape: tuple[int, ...], rhs_shape: tuple[int, ...],
    dimension_numbers: DotDimensionNumbers) -> tuple[int, ...]:
  """Validate the inputs and return the output shape."""
  lhs = core.ShapedArray(lhs_shape, np.float32)
  rhs = core.ShapedArray(rhs_shape, np.float32)
  return _dot_general_shape_rule(
      lhs, rhs, dimension_numbers=dimension_numbers,
      precision=None, preferred_element_type=None)