Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)

Merge pull request #26345 from wenscarl:scaled_matmul
PiperOrigin-RevId: 731865430
This commit is contained in: commit c7ca35fe32

696  jax/_src/cudnn/scaled_matmul_stablehlo.py  (new file)
@@ -0,0 +1,696 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import operator
|
||||
from functools import partial, reduce
|
||||
from typing import List
|
||||
|
||||
# Third-party imports
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax import custom_vjp, lax
|
||||
from jax._src import core, dispatch, dtypes
|
||||
from jax._src.custom_partitioning import custom_partitioning
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.lax.lax import ranges_like, remaining
|
||||
from jax._src.typing import DTypeLike
|
||||
from jax.interpreters import mlir, xla
|
||||
from jax.interpreters.mlir import ir
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
|
||||
Array = jnp.ndarray
|
||||
block_scaled_dot_name = "__op$block_scaled_dot"
|
||||
|
||||
@dataclass
|
||||
class BlockScaleConfig:
|
||||
mode: str
|
||||
block_size: int
|
||||
data_type: DTypeLike
|
||||
scale_type: DTypeLike
|
||||
global_scale: Array | None
|
||||
infer_only: bool
|
||||
|
||||
def default_layouts(*shapes):
|
||||
return [range(len(shape) - 1, -1, -1) for shape in shapes]
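A note for reference (not part of the change): this helper yields one descending, row-major layout per operand shape, in the order XLA custom calls expect. A tiny check:

layouts = default_layouts((2, 128, 64), (2, 64))
assert [list(l) for l in layouts] == [[2, 1, 0], [1, 0]]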
|
||||
|
||||
def element_type_to_backend_config_type(dtype):
|
||||
_element_type_to_backend_config_type_mapping = {
|
||||
ir.BF16Type.get(): "BF16",
|
||||
ir.F16Type.get(): "F16",
|
||||
ir.F32Type.get(): "F32",
|
||||
}
|
||||
return _element_type_to_backend_config_type_mapping[dtype]
|
||||
|
||||
|
||||
def _scaled_matmul_impl(a, b, a_scale, b_scale, preferred_element_type):
|
||||
return _scaled_matmul_p.bind(
|
||||
a, b, a_scale, b_scale, preferred_element_type=preferred_element_type
|
||||
)
|
||||
|
||||
|
||||
def _scaled_matmul_cuda_lowering(
|
||||
ctx, a, b, a_scales, b_scales, preferred_element_type
|
||||
):
|
||||
lhs_type = ir.RankedTensorType(a.type)
|
||||
lhs_shape = lhs_type.shape
|
||||
rhs_type = ir.RankedTensorType(b.type)
|
||||
rhs_shape = rhs_type.shape
|
||||
|
||||
batch, non_contracting_lhs, contracting = lhs_shape
|
||||
_, non_contracting_rhs, _ = rhs_shape
|
||||
result_shape = (batch, non_contracting_lhs, non_contracting_rhs)
|
||||
|
||||
out_type = mlir.dtype_to_ir_type(preferred_element_type)
|
||||
result_types = [ir.RankedTensorType.get(result_shape, out_type)]
|
||||
|
||||
operands = [a, b, a_scales, b_scales]
|
||||
backend_config = {
|
||||
"scaled_dot_backend_config": {
|
||||
"lhs_batch_dimensions": [0],
|
||||
"rhs_batch_dimensions": [0],
|
||||
"dequantize_type": element_type_to_backend_config_type(out_type),
|
||||
}
|
||||
}
|
||||
|
||||
backend_config = json.dumps(backend_config)
|
||||
out = mlir.custom_call(
|
||||
block_scaled_dot_name,
|
||||
result_types=result_types,
|
||||
operands=operands,
|
||||
backend_config=backend_config,
|
||||
operand_layouts=default_layouts(
|
||||
*[ir.RankedTensorType(operand.type).shape for operand in operands]
|
||||
),
|
||||
result_layouts=default_layouts(result_shape),
|
||||
)
|
||||
return [out.result]
|
||||
|
||||
|
||||
def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type):
|
||||
a_dtype = dtypes.canonicalize_dtype(a.dtype)
|
||||
batch, non_contracting_lhs, contracting_lhs = a.shape
|
||||
_, non_contracting_rhs, _ = b.shape
|
||||
output_shape = (batch, non_contracting_lhs, non_contracting_rhs)
|
||||
return (core.ShapedArray(output_shape, preferred_element_type),)
|
||||
|
||||
|
||||
_scaled_matmul_p = core.Primitive("scaled_matmul")
|
||||
_scaled_matmul_p.multiple_results = True
|
||||
_scaled_matmul_p.def_impl(partial(xla.apply_primitive, _scaled_matmul_p))
|
||||
_scaled_matmul_p.def_abstract_eval(_scaled_matmul_abstract)
|
||||
|
||||
|
||||
mlir.register_lowering(
|
||||
_scaled_matmul_p,
|
||||
_scaled_matmul_cuda_lowering,
|
||||
platform="cuda",
|
||||
)
|
||||
|
||||
_scaled_matmul_p_wrapper = core.Primitive("scaled_matmul_wrapper")
|
||||
_scaled_matmul_p_wrapper.multiple_results = True
|
||||
_scaled_matmul_p_wrapper.def_impl(_scaled_matmul_impl)
|
||||
_scaled_matmul_p_wrapper.def_abstract_eval(_scaled_matmul_abstract)
|
||||
|
||||
# Given the inputs already sharded as
|
||||
# ([B], M, K1), ([B], N, K2)
|
||||
# We define the following rule to apply necessary AllGather based on
|
||||
# "Input specs", and to define the "Output spec".
|
||||
# 1. If K1 == K2 != None and N == None:
|
||||
# - Input spec : ([B], M, K1), ([B], None, K2)
|
||||
# - Output spec: ([B], M, None) -> AllReduce -> ([B], M, None)
|
||||
# 2. If K1 == K2 != None and M == N != None:
|
||||
# - Input spec : ([B], M, K1), ([B], None, K2)
|
||||
# - Output spec: ([B], M, None) -> ReduceScatter -> ([B], M, N)
|
||||
# 3. If N == M:
|
||||
# - Input specs : ([B], M, None), ([B], None, None)
|
||||
# - Output specs: ([B], M, None)
|
||||
# 4. If N != M:
|
||||
# - Input spec : ([B], M, None), ([B], N, None)
|
||||
# - Output spec: ([B], M, N)
|
||||
def _check_shardings(shardings):
|
||||
if len(shardings) != 4:
|
||||
msg = f"shardings should container 4 inputs, but got {len(shardings)}"
|
||||
raise TypeError(msg)
|
||||
lhs, rhs, _, _ = shardings
|
||||
if len(lhs.spec) != 3 or len(rhs.spec) != 3:
|
||||
msg = (f'shardings specs rank should be 3, but got lhs: {len(lhs.spec)} '
f'and rhs: {len(rhs.spec)}')
|
||||
raise TypeError(msg)
|
||||
if lhs.spec[0] != rhs.spec[0]:
|
||||
msg = (f'shardings spec for the batch dim should be the same, but got lhs: '
f'{lhs.spec[0]} and rhs: {rhs.spec[0]}')
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
def _enable_reduce_scatter(lhs, rhs):
|
||||
batch_spec, m_spec, lhs_k_spec = lhs.spec
|
||||
_, n_spec, rhs_k_spec = rhs.spec
|
||||
return (
|
||||
lhs_k_spec != None
|
||||
and lhs_k_spec == rhs_k_spec
|
||||
and m_spec != None
|
||||
and m_spec == n_spec
|
||||
)
|
||||
|
||||
|
||||
def _enable_all_reduce(lhs, rhs):
|
||||
batch_spec, m_spec, lhs_k_spec = lhs.spec
|
||||
_, n_spec, rhs_k_spec = rhs.spec
|
||||
return lhs_k_spec != None and lhs_k_spec == rhs_k_spec and n_spec == None
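A minimal sketch of how the two predicates above map onto the numbered cases in the preceding comment block. The `_Spec` stand-in and the "fsdp" axis name are hypothetical; the predicates only read the `.spec` attribute of a sharding.

from dataclasses import dataclass
from jax.sharding import PartitionSpec as P

@dataclass
class _Spec:
    spec: P

# Case 1: contracting dims sharded on both operands, N unsharded -> AllReduce.
assert _enable_all_reduce(_Spec(P("dp", None, "tp")), _Spec(P("dp", None, "tp")))
# Case 2: contracting dims and both non-contracting dims sharded -> ReduceScatter.
assert _enable_reduce_scatter(_Spec(P("dp", "tp", "fsdp")), _Spec(P("dp", "tp", "fsdp")))
# Cases 3/4: contracting dims unsharded -> no cross-device reduction is inserted.
assert not _enable_all_reduce(_Spec(P("dp", "tp", None)), _Spec(P("dp", "tp", None)))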
|
||||
|
||||
|
||||
def _get_output_sharding(mesh, shardings):
|
||||
lhs, rhs = shardings[0], shardings[1]
|
||||
batch_spec, m_spec, _ = lhs.spec
|
||||
_, n_spec, _ = rhs.spec
|
||||
|
||||
if _enable_reduce_scatter(lhs, rhs):
|
||||
return [NamedSharding(lhs.mesh, P(*lhs.spec))]
|
||||
|
||||
output_specs = (batch_spec, m_spec)
|
||||
output_specs += (n_spec,) if m_spec != n_spec else (None,)
|
||||
return [NamedSharding(lhs.mesh, P(*output_specs))]
|
||||
|
||||
|
||||
def _scaled_matmul_infer_sharding_from_operands(
|
||||
preferred_element_type, mesh, shapes, output_shape
|
||||
):
|
||||
shardings = jax.tree.map(lambda x: x.sharding, shapes)
|
||||
_check_shardings(shardings)
|
||||
|
||||
return _get_output_sharding(mesh, shardings)
|
||||
|
||||
|
||||
def supported_in_sharding(mesh, shardings):
|
||||
lhs_sharding, rhs_sharding = shardings[0], shardings[1]
|
||||
use_reduce_scatter = _enable_reduce_scatter(lhs_sharding, rhs_sharding)
|
||||
use_all_reduce = _enable_all_reduce(lhs_sharding, rhs_sharding)
|
||||
assert not (use_all_reduce and use_reduce_scatter)
|
||||
|
||||
lhs_specs, rhs_specs = list(lhs_sharding.spec), list(rhs_sharding.spec)
|
||||
|
||||
def named_sharding(lhs, rhs, lhs_specs, rhs_specs):
|
||||
lhs_sharding = NamedSharding(lhs.mesh, P(*lhs_specs))
|
||||
rhs_sharding = NamedSharding(rhs.mesh, P(*rhs_specs))
|
||||
return (lhs_sharding, rhs_sharding, lhs_sharding, rhs_sharding)
|
||||
|
||||
if use_all_reduce:
|
||||
return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs)
|
||||
|
||||
if use_reduce_scatter:
|
||||
rhs_specs[1] = None
|
||||
return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs)
|
||||
|
||||
lhs_specs[2] = None
|
||||
rhs_specs[2] = None
|
||||
m_spec, n_spec = lhs_specs[1], rhs_specs[1]
|
||||
if m_spec == n_spec:
|
||||
rhs_specs[1] = None
|
||||
|
||||
return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs)
|
||||
|
||||
|
||||
def _scaled_matmul_partition(
|
||||
preferred_element_type, mesh, shapes, output_shape
|
||||
):
|
||||
shardings = jax.tree.map(lambda x: x.sharding, shapes)
|
||||
_check_shardings(shardings)
|
||||
|
||||
lhs, rhs = shardings[0], shardings[1]
|
||||
use_all_reduce = _enable_all_reduce(lhs, rhs)
|
||||
use_reduce_scatter = _enable_reduce_scatter(lhs, rhs)
|
||||
lhs_k_spec = lhs.spec[2]
|
||||
|
||||
def _scaled_matmul_impl_partition(a, b, a_scale, b_scale):
|
||||
z = _scaled_matmul_impl(a, b, a_scale, b_scale, preferred_element_type)
|
||||
if use_reduce_scatter:
|
||||
z = jax.lax.psum_scatter(
|
||||
z, lhs_k_spec, scatter_dimension=2, tiled=True
|
||||
)
|
||||
if use_all_reduce:
|
||||
z = jax.lax.psum(z, lhs_k_spec)
|
||||
return z
|
||||
|
||||
out_shardings = _get_output_sharding(mesh, shardings)
|
||||
arg_shardings = supported_in_sharding(mesh, shardings)
|
||||
return mesh, _scaled_matmul_impl_partition, out_shardings, arg_shardings
|
||||
|
||||
|
||||
_scaled_matmul_lower = custom_partitioning(
|
||||
_scaled_matmul_impl, static_argnums=(4,)
|
||||
)
|
||||
|
||||
_scaled_matmul_lower.def_partition(
|
||||
infer_sharding_from_operands=_scaled_matmul_infer_sharding_from_operands,
|
||||
partition=_scaled_matmul_partition,
|
||||
)
|
||||
|
||||
|
||||
def _scaled_matmul_batcher(batched_args, batch_dims, *, preferred_element_type):
|
||||
assert len(batch_dims) == 4
|
||||
assert (
|
||||
batch_dims[0] == batch_dims[1]
|
||||
and batch_dims[0] == batch_dims[2]
|
||||
and batch_dims[0] == batch_dims[3]
|
||||
)
|
||||
lhs_bdims = batch_dims[0]
|
||||
out_bdims = (batch_dims[0],)
|
||||
lhs, rhs, lhs_scales, rhs_scales = batched_args
|
||||
*batch, lhs_non_contracting, contracting = lhs.shape
|
||||
*_, _, scales_contracting = lhs_scales.shape
|
||||
*_, rhs_non_contracting, _ = rhs.shape
|
||||
|
||||
new_batch = reduce(operator.mul, batch)
|
||||
# reshape to 3D shape
|
||||
lhs = jnp.reshape(lhs, (new_batch, lhs_non_contracting, contracting))
|
||||
lhs_scales = jnp.reshape(
|
||||
lhs_scales, (new_batch, lhs_non_contracting, scales_contracting)
|
||||
)
|
||||
rhs = jnp.reshape(rhs, (new_batch, rhs_non_contracting, contracting))
|
||||
rhs_scales = jnp.reshape(
|
||||
rhs_scales, (new_batch, rhs_non_contracting, scales_contracting)
|
||||
)
|
||||
output = jnp.reshape(
|
||||
_scaled_matmul_p_wrapper.bind(
|
||||
lhs,
|
||||
rhs,
|
||||
lhs_scales,
|
||||
rhs_scales,
|
||||
preferred_element_type=preferred_element_type,
|
||||
)[0],
|
||||
(*batch, lhs_non_contracting, rhs_non_contracting),
|
||||
)
|
||||
return (output,), out_bdims
|
||||
|
||||
|
||||
mlir.register_lowering(
|
||||
_scaled_matmul_p_wrapper,
|
||||
mlir.lower_fun(_scaled_matmul_lower, multiple_results=True),
|
||||
)
|
||||
|
||||
dispatch.prim_requires_devices_during_lowering.add(_scaled_matmul_p)
|
||||
dispatch.prim_requires_devices_during_lowering.add(_scaled_matmul_p_wrapper)
|
||||
|
||||
batching.primitive_batchers[_scaled_matmul_p_wrapper] = _scaled_matmul_batcher
|
||||
batching.primitive_batchers[_scaled_matmul_p] = _scaled_matmul_batcher
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnames=("preferred_element_type",))
|
||||
def _scaled_matmul(
|
||||
lhs: Array,
|
||||
rhs: Array,
|
||||
lhs_scales: Array,
|
||||
rhs_scales: Array,
|
||||
preferred_element_type: DTypeLike = jnp.float32,
|
||||
) -> Array:
|
||||
output = _scaled_matmul_p_wrapper.bind(
|
||||
lhs, rhs, lhs_scales, rhs_scales,
|
||||
preferred_element_type=preferred_element_type
|
||||
)
|
||||
return output[0]
|
||||
|
||||
def scaled_matmul_wrapper(
|
||||
lhs: Array,
|
||||
rhs: Array,
|
||||
lhs_scales: Array,
|
||||
rhs_scales: Array,
|
||||
preferred_element_type: DTypeLike = jnp.float32,
|
||||
) -> Array:
|
||||
"""
|
||||
Performs scaled matrix multiplication between two 3D arrays, with scaling
|
||||
factors applied to the matrices.
|
||||
|
||||
Args:
|
||||
lhs (Array): A 3D array of shape (B, M, K).
|
||||
rhs (Array): A 3D array of shape (B, N, K).
|
||||
lhs_scales (Array): A 3D array of shape (B, M, K_block).
|
||||
rhs_scales (Array): A 3D array of shape (B, N, K_block).
|
||||
preferred_element_type (DTypeLike, optional): The preferred data type
|
||||
for the computation. Defaults to `jnp.float32`.
|
||||
|
||||
Returns:
|
||||
Array: A 3D array of shape (B, M, N) representing the scaled matrix
|
||||
multiplication result.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the number of columns in `lhs` (`lhs_K`) does not
|
||||
match the number of columns in `rhs` (`rhs_K`).
|
||||
|
||||
Notes:
|
||||
- The function ensures that the `preferred_element_type` is
|
||||
canonicalized before passing it to the underlying computation.
|
||||
- Scaling is applied to the matrices based on the `lhs_scales` and
|
||||
`rhs_scales` arrays, enabling efficient computations in blocks.
|
||||
|
||||
"""
|
||||
B, M, lhs_K = lhs.shape
|
||||
_, N, rhs_K = rhs.shape
|
||||
assert lhs_K == rhs_K
|
||||
_, _, K_block = lhs_scales.shape
|
||||
|
||||
preferred_element_type = dtypes.canonicalize_dtype(
|
||||
np.dtype(preferred_element_type)
|
||||
)
|
||||
out = _scaled_matmul(
|
||||
lhs,
|
||||
rhs,
|
||||
lhs_scales,
|
||||
rhs_scales,
|
||||
preferred_element_type=preferred_element_type,
|
||||
)
|
||||
return out
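A shape-level usage sketch of the wrapper above. The sizes are hypothetical; running it end to end requires a cuDNN build with block-scaled dot support and an ml_dtypes recent enough to provide float8_e8m0fnu.

import jax.numpy as jnp
from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper

B, M, N, K, block = 2, 128, 128, 64, 32  # hypothetical sizes, K divisible by the block size
lhs = jnp.ones((B, M, K), jnp.float8_e4m3fn)  # quantized operands
rhs = jnp.ones((B, N, K), jnp.float8_e4m3fn)
lhs_scales = jnp.ones((B, M, K // block), jnp.float8_e8m0fnu)  # one scale per block
rhs_scales = jnp.ones((B, N, K // block), jnp.float8_e8m0fnu)
out = scaled_matmul_wrapper(lhs, rhs, lhs_scales, rhs_scales,
                            preferred_element_type=jnp.float32)
# out has shape (B, M, N) and dtype float32.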
|
||||
|
||||
def shape_normalization(x, dimension_numbers):
|
||||
"""
|
||||
Normalizes the shape of the input tensor `x` to `(B, M, K)`.
|
||||
|
||||
This function rearranges and reshapes the input tensor `x` such that:
|
||||
- `B` represents the batch dimensions.
|
||||
- `M` represents the non-contracting dimensions.
|
||||
- `K` represents the contracting dimensions.
|
||||
|
||||
The dimensions are reordered and reshaped based on the provided
|
||||
`dimension_numbers`.
|
||||
|
||||
Parameters:
|
||||
x: The input tensor to normalize.
|
||||
dimension_numbers: A tuple containing two elements:
|
||||
- `batch_dims` (tuple): The dimensions of `x` to be treated as batch
|
||||
dimensions.
|
||||
- `contracting_dims` (tuple): The dimensions of `x` to be treated as
|
||||
contracting dimensions.
|
||||
|
||||
Returns:
|
||||
jax.numpy.ndarray: The reshaped tensor with shape `(B, M, K)`
|
||||
"""
|
||||
|
||||
orig_order = list(range(x.ndim))
|
||||
contracting_dims, batch_dims = dimension_numbers
|
||||
contracting_order = [d for d in orig_order if d in contracting_dims]
|
||||
batch_order = [d for d in orig_order if d in batch_dims]
|
||||
non_contracting_order = [
|
||||
d
|
||||
for d in orig_order
|
||||
if d not in contracting_dims and d not in batch_dims
|
||||
]
|
||||
batch_shape = [x.shape[d] for d in batch_order]
|
||||
rows_shape = [x.shape[d] for d in non_contracting_order]
|
||||
cols_shape = [x.shape[d] for d in contracting_order]
|
||||
new_order = batch_order + non_contracting_order + contracting_order
|
||||
rows, cols, batches = (
|
||||
np.prod(rows_shape),
|
||||
np.prod(cols_shape),
|
||||
np.prod(batch_shape, dtype=int),
|
||||
)
|
||||
t = jnp.transpose(x, new_order)
|
||||
return jnp.reshape(t, (batches, rows, cols))
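A small concrete case of the normalization above; dimension_numbers here follow the (contracting_dims, batch_dims) convention used by the callers below.

x = jnp.zeros((2, 4, 3, 8))                  # (batch, rows, heads, K)
x_3d = shape_normalization(x, ((3,), (0,)))  # contract over dim 3, batch over dim 0
assert x_3d.shape == (2, 4 * 3, 8)           # (B, M, K) = (2, 12, 8)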
|
||||
|
||||
|
||||
def compute_dot_output_shape(
|
||||
lhs_shape, rhs_shape, lhs_dimension_numbers, rhs_dimension_numbers
|
||||
):
|
||||
"""
|
||||
Computes the output shape for a `lax.dot_general`-like operation.
|
||||
"""
|
||||
lhs_contract, lhs_batch = lhs_dimension_numbers[0], lhs_dimension_numbers[1]
|
||||
rhs_contract, rhs_batch = rhs_dimension_numbers[0], rhs_dimension_numbers[1]
|
||||
|
||||
output_shape = []
|
||||
# Add dimensions for batch (assuming the batch dims of LHS and RHS
|
||||
# should be same)
|
||||
for i, dim in enumerate(lhs_shape):
|
||||
if i in lhs_batch:
|
||||
output_shape.append(dim)
|
||||
# Add dimensions from the LHS that are non contracting
|
||||
for i, dim in enumerate(lhs_shape):
|
||||
if i not in lhs_contract and i not in lhs_batch:
|
||||
output_shape.append(dim)
|
||||
# Add dimensions from the RHS that are non contracting
|
||||
for i, dim in enumerate(rhs_shape):
|
||||
if i not in rhs_contract and i not in rhs_batch:
|
||||
output_shape.append(dim)
|
||||
return tuple(output_shape)
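For example, a batched contraction over the last dimensions of both operands:

out_shape = compute_dot_output_shape(
    (2, 16, 8),    # lhs shape (B, M, K)
    (2, 32, 8),    # rhs shape (B, N, K)
    ((2,), (0,)),  # lhs (contracting, batch) dims
    ((2,), (0,)),  # rhs (contracting, batch) dims
)
assert out_shape == (2, 16, 32)  # batch dims, then lhs M, then rhs N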
|
||||
|
||||
|
||||
def cast_to_e8m0_with_rounding_up(x):
|
||||
temp = x.astype(jnp.float32).view(jnp.uint32)
|
||||
exp = temp >> 23
|
||||
mant = temp & 0x7FFFFF
|
||||
is_ru = jnp.logical_and(
|
||||
jnp.logical_and((mant > 0), (exp != 0xFE)),
|
||||
~jnp.logical_and((exp == 0), (mant <= 0x400000))
|
||||
)
|
||||
exp = jnp.where(is_ru, exp + 1, exp)
|
||||
new_x = exp.astype(jnp.uint8)
|
||||
return new_x
|
||||
|
||||
|
||||
def e8m0_to_dtype(x, dtype):
|
||||
temp = x.astype(jnp.uint32)
|
||||
exp = temp << 23
|
||||
new_x = exp.view(jnp.float32)
|
||||
near_zero_value = 2**-15 if dtype == jnp.float16 else 2**-127
|
||||
new_x = jnp.where(
|
||||
new_x == 0, jnp.array(near_zero_value, jnp.float32), new_x
|
||||
)
|
||||
return new_x.astype(dtype)
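These two conversions encode a scale as a bare biased exponent (e8m0), rounding up whenever mantissa bits would be dropped. A quick round trip illustrates the behavior:

s = jnp.array([2.0, 3.0], jnp.float32)
e = cast_to_e8m0_with_rounding_up(s)   # uint8 biased exponents: [128, 129]
back = e8m0_to_dtype(e, jnp.float32)
# 2.0 is a power of two and survives exactly; 3.0 rounds up to 4.0.
assert jnp.allclose(back, jnp.array([2.0, 4.0]))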
|
||||
|
||||
def quantize(x, config):
|
||||
x_shape = x.shape
|
||||
contract_dim = x_shape[-1]
|
||||
block_size = config.block_size
|
||||
assert contract_dim >= block_size and contract_dim % block_size == 0
|
||||
x_new_shape = x_shape[:-1] + (x_shape[-1] // block_size, block_size)
|
||||
x = x.reshape(x_new_shape) # shape = (B, M, K / block_size, block_size)
|
||||
|
||||
amax = jnp.max(jnp.abs(x), axis=-1, keepdims=True)
|
||||
MAX = jnp.finfo(config.data_type).max.astype(x.dtype)
|
||||
scales = amax / MAX # shape = (B, M, K / block_size, 1)
|
||||
|
||||
if config.mode == "mxfp8":
|
||||
assert config.scale_type == jnp.float8_e8m0fnu
|
||||
scales_q = cast_to_e8m0_with_rounding_up(scales)
|
||||
scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype)
|
||||
elif config.mode == "nvfp4":
|
||||
assert config.scale_type == jnp.float8_e4m3fn
|
||||
# shuw(TODO): Add when XLA is ready and e2m1fn is available.
|
||||
scales_q = scales
|
||||
scaled_x = x
|
||||
else:
|
||||
raise ValueError(f"Unrecognized mode: {config.mode}.")
|
||||
|
||||
clipped_x = jnp.clip(scaled_x, -MAX, MAX)
|
||||
x_q = clipped_x.astype(config.data_type)
|
||||
|
||||
x_q = x_q.reshape(x_shape) # shape = (B, M, K)
|
||||
scales_q = jnp.reshape(scales_q, scales_q.shape[:-1]).view(
|
||||
config.scale_type
|
||||
)
|
||||
return x_q, scales_q
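A minimal mxfp8 quantization sketch (assumes float8_e8m0fnu is available): the payload keeps the input shape, while the scales collapse each 32-wide block of the contracting dimension to a single e8m0 entry.

cfg = BlockScaleConfig(mode="mxfp8", block_size=32,
                       data_type=jnp.float8_e4m3fn,
                       scale_type=jnp.float8_e8m0fnu,
                       global_scale=None, infer_only=False)
x = jnp.linspace(-1.0, 1.0, 1 * 4 * 64).reshape(1, 4, 64).astype(jnp.float32)
x_q, x_scales = quantize(x, cfg)
assert x_q.shape == (1, 4, 64) and x_q.dtype == jnp.float8_e4m3fn
assert x_scales.shape == (1, 4, 2) and x_scales.dtype == jnp.float8_e8m0fnu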
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
|
||||
configs):
|
||||
if preferred_element_type is None:
|
||||
preferred_element_type = dtypes.result_type(
|
||||
lhs, rhs, return_weak_type_flag=False
|
||||
)
|
||||
else:
|
||||
preferred_element_type = dtypes.canonicalize_dtype(
|
||||
np.dtype(preferred_element_type)
|
||||
)
|
||||
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
lhs_dn = (lhs_contract, lhs_batch)
|
||||
rhs_dn = (rhs_contract, rhs_batch)
|
||||
|
||||
lhs_3d = shape_normalization(lhs, lhs_dn)
|
||||
rhs_3d = shape_normalization(rhs, rhs_dn)
|
||||
lhs_config, rhs_config = configs[0], configs[1]
|
||||
lhs_q, lhs_scales = quantize(lhs_3d, lhs_config)
|
||||
rhs_q, rhs_scales = quantize(rhs_3d, rhs_config)
|
||||
|
||||
out = scaled_matmul_wrapper(
|
||||
lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type
|
||||
)
|
||||
|
||||
expanded_out_shape = compute_dot_output_shape(
|
||||
lhs.shape, rhs.shape, lhs_dn, rhs_dn
|
||||
)
|
||||
expanded_out = jnp.reshape(out, expanded_out_shape)
|
||||
return expanded_out
|
||||
|
||||
|
||||
def scaled_dot_general_transpose_lhs(
|
||||
g, x, y, *, dimension_numbers, preferred_element_type, configs,
|
||||
swap_ans=False
|
||||
):
|
||||
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
|
||||
x_ndim = x.aval.ndim
|
||||
x_kept = remaining(range(x_ndim), x_contract, x_batch)
|
||||
y_kept = remaining(range(np.ndim(y)), y_contract, y_batch)
|
||||
if swap_ans:
|
||||
ans_batch, ans_y, _ = ranges_like(x_batch, y_kept, x_kept)
|
||||
else:
|
||||
ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
|
||||
|
||||
dims = ((ans_y, y_kept), (ans_batch, y_batch))
|
||||
x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
|
||||
out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
|
||||
|
||||
y_dn = (y_kept, y_batch)
|
||||
g_dn = (ans_y, ans_batch)
|
||||
|
||||
y_3d = shape_normalization(y, y_dn)
|
||||
g_3d = shape_normalization(g, g_dn)
|
||||
|
||||
g_config, y_config = configs[0], configs[1]
|
||||
|
||||
g_q, g_scales = quantize(g_3d, g_config)
|
||||
y_q, y_scales = quantize(y_3d, y_config)
|
||||
|
||||
out = scaled_matmul_wrapper(
|
||||
g_q, y_q, g_scales, y_scales, preferred_element_type
|
||||
)
|
||||
|
||||
expanded_out_shape = compute_dot_output_shape(g.shape, y.shape, g_dn, y_dn)
|
||||
expanded_out = jnp.reshape(out, expanded_out_shape)
|
||||
x_bar = lax.transpose(expanded_out, tuple(out_axes))
|
||||
return x_bar
|
||||
|
||||
|
||||
def scaled_dot_general_transpose_rhs(
|
||||
g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike,
|
||||
configs: List[BlockScaleConfig]
|
||||
):
|
||||
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
|
||||
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
|
||||
y_bar = scaled_dot_general_transpose_lhs(
|
||||
g,
|
||||
y,
|
||||
x,
|
||||
dimension_numbers=swapped_dimension_numbers,
|
||||
preferred_element_type=preferred_element_type,
|
||||
configs=configs,
|
||||
swap_ans=True,
|
||||
)
|
||||
return y_bar
|
||||
|
||||
|
||||
@partial(custom_vjp, nondiff_argnums=(2, 3, 4))
|
||||
def scaled_dot_general_fn(lhs, rhs, dimension_numbers, preferred_element_type,
|
||||
configs):
|
||||
return scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
|
||||
configs)
|
||||
|
||||
|
||||
def scaled_dot_fwd(lhs, rhs, dimension_numbers, preferred_element_type,
|
||||
configs):
|
||||
out = scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
|
||||
configs)
|
||||
res = (lhs, rhs)
|
||||
return out, res
|
||||
|
||||
|
||||
def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g):
|
||||
(lhs, rhs) = res
|
||||
|
||||
args = [g, lhs, rhs]
|
||||
kw_args = {
|
||||
"dimension_numbers": dimension_numbers,
|
||||
"preferred_element_type": preferred_element_type,
|
||||
}
|
||||
lhs_kw_args = {
|
||||
**kw_args,
|
||||
"configs": [configs[2], configs[1]]
|
||||
}
|
||||
rhs_kw_args = {
|
||||
**kw_args,
|
||||
"configs": [configs[2], configs[0]]
|
||||
}
|
||||
grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args)
|
||||
grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args)
|
||||
return (grad_lhs, grad_rhs)
|
||||
|
||||
|
||||
scaled_dot_general_fn.defvjp(scaled_dot_fwd, scaled_dot_bwd)
|
||||
|
||||
|
||||
def ensure_tuple(dimension_numbers):
|
||||
_to_tuple = lambda x: x if isinstance(x, tuple) else tuple(x)
|
||||
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
lhs_contract = _to_tuple(lhs_contract)
|
||||
rhs_contract = _to_tuple(rhs_contract)
|
||||
lhs_batch = _to_tuple(lhs_batch)
|
||||
rhs_batch = _to_tuple(rhs_batch)
|
||||
return (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
|
||||
|
||||
|
||||
def _ensure_batch_dim(lhs, rhs, dimension_numbers):
|
||||
contracting_dims, (lhs_batch, rhs_batch) = dimension_numbers
|
||||
lhs_batched = lhs
|
||||
rhs_batched = rhs
|
||||
|
||||
if lhs_batch == (): # expand the last dim
|
||||
lhs_batched = jnp.expand_dims(lhs, axis=lhs.aval.ndim)
|
||||
lhs_batch = (lhs.aval.ndim,)
|
||||
if rhs_batch == ():
|
||||
rhs_batched = jnp.expand_dims(rhs, axis=rhs.aval.ndim)
|
||||
rhs_batch = (rhs.aval.ndim,)
|
||||
dn_batched = contracting_dims, (lhs_batch, rhs_batch)
|
||||
return lhs_batched, rhs_batched, dn_batched
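For a plain 2D matmul with no batch dimensions, the helper appends a singleton batch axis to both operands and points the batch dimension numbers at it:

lhs = jnp.zeros((16, 8))
rhs = jnp.zeros((8, 24))
dn = ((1,), (0,)), ((), ())  # contract lhs dim 1 with rhs dim 0, no batch dims
lhs_b, rhs_b, dn_b = _ensure_batch_dim(lhs, rhs, dn)
assert lhs_b.shape == (16, 8, 1) and rhs_b.shape == (8, 24, 1)
assert dn_b == (((1,), (0,)), ((2,), (2,)))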
|
||||
|
||||
|
||||
def scaled_dot_general_wrapper(
|
||||
lhs, rhs, dimension_numbers,
|
||||
preferred_element_type=jnp.float32,
|
||||
configs: List[BlockScaleConfig] | None=None,
|
||||
):
|
||||
if preferred_element_type not in (jnp.float32, jnp.bfloat16, jnp.float16):
|
||||
msg = (f'Only support preferred_element_type in (f32, bf16, f16), but got '
f'{preferred_element_type}')
|
||||
raise TypeError(msg)
|
||||
if configs is None:
|
||||
mxfp8_config = BlockScaleConfig(
|
||||
mode='mxfp8',
|
||||
block_size=32,
|
||||
data_type=jnp.float8_e4m3fn,
|
||||
scale_type=jnp.float8_e8m0fnu,
|
||||
global_scale=None,
|
||||
infer_only=False
|
||||
)
|
||||
configs = [mxfp8_config, mxfp8_config, mxfp8_config]
|
||||
|
||||
dimension_numbers = ensure_tuple(dimension_numbers)
|
||||
lhs_batched, rhs_batched, dn_batched = _ensure_batch_dim(
|
||||
lhs, rhs, dimension_numbers
|
||||
)
|
||||
out = scaled_dot_general_fn(
|
||||
lhs_batched, rhs_batched, dn_batched, preferred_element_type, configs,
|
||||
)
|
||||
|
||||
# Expanding batch dims for operands adds a singleton batch dim at axis 0 in
|
||||
# the output, which we need to squeeze.
|
||||
if dn_batched != dimension_numbers:
|
||||
return jnp.squeeze(out, axis=0)
|
||||
return out
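End-to-end usage sketch for the wrapper (hypothetical shapes; executing it requires hardware and a cuDNN build that support the block-scaled dot, since the default configs quantize to mxfp8 and lower to the custom call above):

k1, k2 = jax.random.split(jax.random.key(0))
a = jax.random.normal(k1, (2, 128, 64), jnp.float32)
b = jax.random.normal(k2, (2, 256, 64), jnp.float32)
# Contract the last dims and batch over dim 0, as in lax.dot_general.
out = scaled_dot_general_wrapper(
    a, b, (((2,), (2,)), ((0,), (0,))),
    preferred_element_type=jnp.float32,
)
# out.shape == (2, 128, 256)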
|
@@ -21,7 +21,7 @@ from functools import partial
|
||||
import operator
|
||||
import math
|
||||
import numpy as np
|
||||
from typing import Any, Literal
|
||||
from typing import Any, List, Literal
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -36,10 +36,14 @@ from jax._src.core import AxisName
|
||||
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
|
||||
from jax._src.cudnn.fused_attention_stablehlo import (
|
||||
dot_product_attention as cudnn_dot_product_attention, MaskType)
|
||||
from jax._src.cudnn.scaled_matmul_stablehlo import (
|
||||
scaled_matmul_wrapper as cudnn_scaled_matmul,
|
||||
scaled_dot_general_wrapper as cudnn_scaled_dot_general,
|
||||
BlockScaleConfig)
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.numpy import util as numpy_util
|
||||
from jax._src.typing import Array, ArrayLike, DType
|
||||
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
|
||||
from jax._src.ops.special import logsumexp as _logsumexp
|
||||
|
||||
|
||||
@@ -1185,3 +1189,109 @@ def dot_product_attention(
|
||||
raise ValueError(f"Unsupported implementation option: {implementation}")
|
||||
|
||||
return jnp.reshape(out, output_shape)
|
||||
|
||||
def scaled_matmul(
|
||||
lhs: Array,
|
||||
rhs: Array,
|
||||
lhs_scales: Array,
|
||||
rhs_scales: Array,
|
||||
preferred_element_type: DTypeLike = jnp.float32,
|
||||
) -> Array:
|
||||
r"""
|
||||
Performs scaled matrix multiplication between two 3D arrays, with scaling
|
||||
factors applied to the matrices.
|
||||
.. math::
|
||||
\mathrm{ScaledMatmul}(lhs, rhs, lhs\_scales, rhs\_scales) = lhs\_scales \cdot rhs\_scales \cdot \mathrm{dot}(lhs, rhs)
|
||||
Args:
|
||||
lhs (Array): A 3D array of shape (B, M, K).
|
||||
rhs (Array): A 3D array of shape (B, N, K).
|
||||
lhs_scales (Array): A 3D array of shape (B, M, K_block).
|
||||
rhs_scales (Array): A 3D array of shape (B, N, K_block).
|
||||
preferred_element_type (DTypeLike, optional): The preferred data type
|
||||
for the computation. Defaults to `jnp.float32`.
|
||||
Returns:
|
||||
Array: A 3D array of shape (B, M, N) representing the scaled matrix
|
||||
multiplication result.
|
||||
Raises:
|
||||
AssertionError: If the number of columns in `lhs` (`lhs_K`) does not
|
||||
match the number of columns in `rhs` (`rhs_K`).
|
||||
Notes:
|
||||
- The function ensures that the `preferred_element_type` is
|
||||
canonicalized before passing it to the underlying computation.
|
||||
- Scaling is applied to the matrices based on the `lhs_scales` and
|
||||
`rhs_scales` arrays, enabling efficient computations in blocks.
|
||||
"""
|
||||
B, M, lhs_K = lhs.shape
|
||||
_, N, rhs_K = rhs.shape
|
||||
assert lhs_K == rhs_K
|
||||
_, _, K_block = lhs_scales.shape
|
||||
|
||||
preferred_element_type = dtypes.canonicalize_dtype(
|
||||
np.dtype(preferred_element_type)
|
||||
)
|
||||
out = cudnn_scaled_matmul(
|
||||
lhs,
|
||||
rhs,
|
||||
lhs_scales,
|
||||
rhs_scales,
|
||||
preferred_element_type=preferred_element_type,
|
||||
)
|
||||
return out
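A sketch of how the result relates to a plain matmul on the original operands, mirroring the unit tests (hypothetical sizes; needs mxfp8-capable hardware to execute, and the comparison only holds up to quantization error):

from jax._src.cudnn.scaled_matmul_stablehlo import quantize, BlockScaleConfig

cfg = BlockScaleConfig(mode="mxfp8", block_size=32,
                       data_type=jnp.float8_e4m3fn,
                       scale_type=jnp.float8_e8m0fnu,
                       global_scale=None, infer_only=False)
a = jax.random.normal(jax.random.key(1), (2, 64, 128), jnp.float32)
b = jax.random.normal(jax.random.key(2), (2, 96, 128), jnp.float32)
a_q, a_scales = quantize(a, cfg)
b_q, b_scales = quantize(b, cfg)
out = scaled_matmul(a_q, b_q, a_scales, b_scales)  # (2, 64, 96), float32
ref = jnp.einsum("BMK,BNK->BMN", a, b)             # agrees with `out` up to quantization error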
|
||||
|
||||
def scaled_dot_general(
|
||||
lhs, rhs,
|
||||
dimension_numbers,
|
||||
preferred_element_type=jnp.float32,
|
||||
configs: List[BlockScaleConfig] | None = None,
|
||||
implementation: Literal['cudnn'] | None = None,
|
||||
):
|
||||
r"""Scaled dot general operation.
|
||||
Computes the scaled dot general on lhs, rhs with quantization specified by configs:
|
||||
.. math::
|
||||
\widehat{lhs}, s_a=\mathrm{quantize}(lhs) \\
|
||||
\widehat{rhs}, s_b=\mathrm{quantize}(rhs) \\
|
||||
\mathrm{ScaledDot}(lhs, rhs)=s_a \cdot s_b \cdot \mathrm{dot}(\widehat{lhs}, \widehat{rhs})
|
||||
Args:
|
||||
lhs: Left-hand side input tensor.
|
||||
rhs: Right-hand side input tensor.
|
||||
dimension_numbers: A tuple specifying the contraction and batch dimensions
|
||||
for the dot general operation. Must follow the format:
|
||||
`((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
|
||||
preferred_element_type: The preferred output data type. Supported types are
|
||||
`jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`.
|
||||
configs: A list of `BlockScaleConfig` specifying the scaling
|
||||
configurations for the operation. Defaults to `mxfp8`.
|
||||
implementation: A string to control which implementation backend to use.
|
||||
Supported strings are `cudnn` (cuDNN block scaled dot). It defaults
|
||||
to `None`, which will automatically select the best available backend.
|
||||
Returns:
|
||||
The result of the scaled dot general operation.
|
||||
"""
|
||||
# Create configs if not provided
|
||||
if configs is None:
|
||||
if dtypes.float8_e8m0fnu is None:
|
||||
raise ValueError("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
|
||||
mxfp8_config = BlockScaleConfig(
|
||||
mode='mxfp8',
|
||||
block_size=32,
|
||||
data_type=jnp.float8_e4m3fn,
|
||||
scale_type=jnp.float8_e8m0fnu,
|
||||
global_scale=None,
|
||||
infer_only=False
|
||||
)
|
||||
configs = [mxfp8_config for _ in range(3)]
|
||||
|
||||
if implementation is None:
|
||||
implementation = 'cudnn'
|
||||
|
||||
match implementation:
|
||||
case 'cudnn':
|
||||
out = cudnn_scaled_dot_general(
|
||||
lhs, rhs, dimension_numbers,
|
||||
preferred_element_type=preferred_element_type,
|
||||
configs=configs
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported implementation option: {implementation}")
|
||||
|
||||
return out
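Usage sketch for the public entry point (hypothetical shapes; per the tests below it currently requires cuDNN >= 9.7 and a Blackwell-class GPU):

import jax
import jax.numpy as jnp
from jax import nn

a = jax.random.normal(jax.random.key(0), (128, 64), jnp.bfloat16)
b = jax.random.normal(jax.random.key(1), (64, 256), jnp.bfloat16)
out = nn.scaled_dot_general(
    a, b, (((1,), (0,)), ((), ())),  # same dimension_numbers convention as lax.dot_general
    preferred_element_type=jnp.bfloat16,
)
# out.shape == (128, 256); gradients flow through the custom VJP defined in the cuDNN module.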
|
||||
|
@@ -206,6 +206,8 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
"dot_product_attention_fp8_bwd_wrapper",
|
||||
):
|
||||
continue
|
||||
if p.name == "scaled_matmul_wrapper":
|
||||
continue
|
||||
if p.name in tf_not_yet_impl:
|
||||
self.assertNotIn(
|
||||
p, tf_impl) # Should not be in both tf_impl and tf_not_yet_impl
|
||||
|
@@ -37,6 +37,8 @@ from jax._src.nn.functions import (
|
||||
relu as relu,
|
||||
relu6 as relu6,
|
||||
dot_product_attention as dot_product_attention,
|
||||
scaled_dot_general as scaled_dot_general,
|
||||
scaled_matmul as scaled_matmul,
|
||||
selu as selu,
|
||||
sigmoid as sigmoid,
|
||||
soft_sign as soft_sign,
|
||||
|
@@ -1656,6 +1656,15 @@ jax_multiplatform_test(
|
||||
tags = ["multiaccelerator"],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "scaled_matmul_stablehlo_test",
|
||||
srcs = ["scaled_matmul_stablehlo_test.py"],
|
||||
enable_backends = ["gpu"],
|
||||
shard_count = {
|
||||
"gpu": 4,
|
||||
},
|
||||
)
|
||||
|
||||
jax_py_test(
|
||||
name = "custom_partitioning_sharding_rule_test",
|
||||
srcs = ["custom_partitioning_sharding_rule_test.py"],
|
||||
|
165  tests/nn_test.py
@@ -27,8 +27,14 @@ import scipy.stats
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dtypes as _dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import cuda_versions
|
||||
from jax._src.cudnn.scaled_matmul_stablehlo import (
|
||||
quantize,
|
||||
shape_normalization,
|
||||
BlockScaleConfig,
|
||||
)
|
||||
from jax.test_util import check_grads
|
||||
from jax import nn
|
||||
from jax import random
|
||||
@@ -37,9 +43,9 @@ import jax.numpy as jnp
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
def _is_required_cudnn_version_satisfied(min_cudnn_version):
|
||||
def _is_required_cudnn_version_satisfied(min_cc, min_cudnn_version):
|
||||
return (
|
||||
jtu.is_cuda_compute_capability_at_least("8.0") and
|
||||
jtu.is_cuda_compute_capability_at_least(min_cc) and
|
||||
cuda_versions is not None and
|
||||
cuda_versions.cudnn_get_version() >= min_cudnn_version
|
||||
)
|
||||
@@ -51,9 +57,158 @@ def _check_cudnn_backend(fn, *args, **kwargs):
|
||||
|
||||
_cudnn_dbias_error = 'cuDNN only supports bias gradient'
|
||||
|
||||
def quantize_to_qtype(x, q_dtype, compute_dtype, scale):
|
||||
# Explicitly cast the max values to the compute dtype to avoid unnecessary
|
||||
# casting to FP32 during the subsequent math operations.
|
||||
assert q_dtype in (jnp.float8_e4m3fn, )
|
||||
dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
|
||||
scaled_x = x / jnp.broadcast_to(
|
||||
jnp.asarray(scale, dtype=compute_dtype), x.shape
|
||||
)
|
||||
clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
|
||||
return clipped_x.astype(q_dtype)
|
||||
|
||||
def quantize_dequantize(x, q_dtype, scale, compute_dtype):
|
||||
qx = quantize_to_qtype(x, q_dtype, compute_dtype, scale)
|
||||
out = qx.astype(x.dtype) * jnp.broadcast_to(
|
||||
jnp.asarray(scale, dtype=x.dtype), qx.shape
|
||||
)
|
||||
return out
|
||||
|
||||
def _generate_quantized_tensors(
|
||||
batch, lhs_non_contract, contract, rhs_non_contract,
|
||||
configs, dtype=jnp.float32,
|
||||
):
|
||||
cast_to_representable = partial(
|
||||
quantize_dequantize,
|
||||
scale=jnp.ones((1,)),
|
||||
compute_dtype=dtype,
|
||||
)
|
||||
|
||||
k1, k2 = jax.random.split(jax.random.key(123), 2)
|
||||
|
||||
a = cast_to_representable(
|
||||
jax.random.uniform(
|
||||
k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
|
||||
),
|
||||
configs[0].data_type,
|
||||
)
|
||||
b = cast_to_representable(
|
||||
jax.random.uniform(
|
||||
k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
|
||||
),
|
||||
configs[1].data_type,
|
||||
)
|
||||
|
||||
dn = ((2,), (0,))
|
||||
a_3d = shape_normalization(a, dn)
|
||||
b_3d = shape_normalization(b, dn)
|
||||
a_q, a_scales = quantize(a, configs[0])
|
||||
b_q, b_scales = quantize(b, configs[1])
|
||||
|
||||
return a, b, a_q, b_q, a_scales, b_scales
|
||||
|
||||
def create_mxfp8_configs_if_available():
|
||||
if _dtypes.float8_e8m0fnu is None:
|
||||
raise unittest.SkipTest("float8_e8m0fnu is not available.")
|
||||
|
||||
def _create_mxfp8_config():
|
||||
return BlockScaleConfig(
|
||||
mode='mxfp8',
|
||||
block_size=32,
|
||||
data_type=jnp.float8_e4m3fn,
|
||||
scale_type=jnp.float8_e8m0fnu,
|
||||
global_scale=None,
|
||||
infer_only=False
|
||||
)
|
||||
|
||||
return [_create_mxfp8_config() for _ in range(3)]
|
||||
|
||||
|
||||
@jtu.with_config(jax_legacy_prng_key="allow",
|
||||
jax_numpy_dtype_promotion="standard")
|
||||
class NNFunctionsTest(jtu.JaxTestCase):
|
||||
@parameterized.product(
|
||||
contract=[160, 96],
|
||||
lhs_non_contract=[240, 100],
|
||||
dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
|
||||
impl=['cudnn',],
|
||||
)
|
||||
def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl):
|
||||
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
|
||||
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
|
||||
# Check if float8_e8m0fnu is available
|
||||
configs = create_mxfp8_configs_if_available()
|
||||
batch, rhs_non_contract = 4, 256
|
||||
a, b, a_q, b_q, a_scales, b_scales = _generate_quantized_tensors(
|
||||
batch, lhs_non_contract, contract, rhs_non_contract,
|
||||
configs, dtype=dtype,
|
||||
)
|
||||
out = nn.scaled_matmul(a_q, b_q, a_scales, b_scales,
|
||||
preferred_element_type=dtype)
|
||||
out_ref = jnp.matmul(a.astype(jnp.float32),
|
||||
jnp.transpose(b, (0, 2, 1)).astype(jnp.float32))
|
||||
self.assertArraysAllClose(
|
||||
out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3
|
||||
)
|
||||
|
||||
@parameterized.product(
|
||||
is_training=[True, False],
|
||||
output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
|
||||
impl=['cudnn',],
|
||||
)
|
||||
def testScaledDotGeneral(
|
||||
self, is_training, output_type, impl):
|
||||
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
|
||||
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
|
||||
|
||||
configs = create_mxfp8_configs_if_available()
|
||||
cast_to_representable = partial(
|
||||
quantize_dequantize,
|
||||
scale=jnp.ones((1,)),
|
||||
compute_dtype=jnp.float32,
|
||||
)
|
||||
k1, k2 = jax.random.split(jax.random.key(0), 2)
|
||||
a_shape = [2, 256, 96]
|
||||
b_shape = [2, 96, 160]
|
||||
dimension_numbers = (([2], [1]), ([0], [0]))
|
||||
a = cast_to_representable(
|
||||
jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type),
|
||||
configs[0].data_type,
|
||||
)
|
||||
b = cast_to_representable(
|
||||
jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type),
|
||||
configs[1].data_type,
|
||||
)
|
||||
|
||||
scaled_dot_general_fn = partial(
|
||||
nn.scaled_dot_general, configs=configs
|
||||
)
|
||||
def fwd(a, b, is_ref=False):
|
||||
fn = jax.lax.dot_general if is_ref else scaled_dot_general_fn
|
||||
y = fn(a, b, dimension_numbers,
|
||||
preferred_element_type=output_type)
|
||||
return jnp.sum(y)
|
||||
|
||||
if is_training:
|
||||
j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))
|
||||
|
||||
j_train_ref = jax.jit(
|
||||
jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
|
||||
)
|
||||
out, (x_grad, w_grad) = j_train(a, b)
|
||||
out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)
|
||||
|
||||
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
|
||||
self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
|
||||
self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
|
||||
else:
|
||||
j_inference = jax.jit(fwd)
|
||||
j_inference_ref = jax.jit(partial(fwd, is_ref=True))
|
||||
out = j_inference(a, b)
|
||||
out_ref = j_inference_ref(a, b)
|
||||
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
|
||||
|
||||
@parameterized.product(
|
||||
dtype=[jnp.bfloat16, jnp.float16],
|
||||
group_num=[1, 2, 4],
|
||||
@@ -61,7 +216,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
impl=['cudnn', 'xla'],
|
||||
)
|
||||
def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
|
||||
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(8904):
|
||||
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("8.0", 8904):
|
||||
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
|
||||
if impl == 'cudnn' and dtype == jnp.float32:
|
||||
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
|
||||
@@ -110,7 +265,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
if isinstance(mask_mode, str):
|
||||
mask_mode = (mask_mode,)
|
||||
min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904
|
||||
if not _is_required_cudnn_version_satisfied(min_cudnn_version):
|
||||
if not _is_required_cudnn_version_satisfied("8.0", min_cudnn_version):
|
||||
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
|
||||
|
||||
dtype = jnp.bfloat16
|
||||
@@ -173,7 +328,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
|
||||
use_vmap=[False, True],
|
||||
)
|
||||
def testDotProductAttentionBiasGradient(self, batch_size, use_vmap):
|
||||
if not _is_required_cudnn_version_satisfied(8904):
|
||||
if not _is_required_cudnn_version_satisfied("8.0", 8904):
|
||||
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
|
||||
|
||||
dtype = jnp.bfloat16
|
||||
|
536  tests/scaled_matmul_stablehlo_test.py  (new file)
@@ -0,0 +1,536 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from absl.testing import absltest
|
||||
|
||||
import re
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import Mesh
|
||||
from jax.sharding import PartitionSpec, NamedSharding
|
||||
from jax._src import config
|
||||
from jax._src import dtypes as _dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
|
||||
from jax._src.cudnn.scaled_matmul_stablehlo import (
|
||||
scaled_matmul_wrapper,
|
||||
scaled_dot_general_wrapper,
|
||||
shape_normalization,
|
||||
quantize,
|
||||
BlockScaleConfig,
|
||||
)
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
input_shardings = [
|
||||
(("dp", None, "tp"), ("dp", None, "tp")),
|
||||
(("dp", None, "tp"), ("dp", None, None)),
|
||||
(("dp", None, "tp"), ("dp", "tp", None)),
|
||||
(("dp", None, None), ("dp", "tp", None)),
|
||||
(("dp", "tp", None), ("dp", "tp", None)),
|
||||
((None, "dp", "tp"), (None, "dp", "tp")),
|
||||
((None, "tp", None), (None, "tp", None)),
|
||||
((None, None, "tp"), (None, "tp", None)),
|
||||
]
|
||||
c_name = "__cudnn$blockScaledDot"
|
||||
expected_hlos = [
|
||||
(c_name, "all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"),
|
||||
("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name),
|
||||
("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name),
|
||||
(c_name,),
|
||||
("all-gather", "f8e4m3fn[1,256,1024]", "replica_groups=[2,2]<=[4]", c_name),
|
||||
(c_name, "reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}"),
|
||||
("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]", c_name),
|
||||
("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]", c_name),
|
||||
]
|
||||
expected_output_spec = [
|
||||
PartitionSpec('dp',),
|
||||
PartitionSpec('dp',),
|
||||
PartitionSpec('dp', None, 'tp'),
|
||||
PartitionSpec('dp', None, 'tp'),
|
||||
PartitionSpec('dp', 'tp', None),
|
||||
PartitionSpec(None, 'dp', 'tp'),
|
||||
PartitionSpec(None, 'tp', None),
|
||||
PartitionSpec(None, None, 'tp'),
|
||||
]
|
||||
sharding_configs = {
|
||||
input_sharding: (hlo, output_spec)
|
||||
for input_sharding, hlo, output_spec in zip(input_shardings, expected_hlos, expected_output_spec)
|
||||
}
|
||||
|
||||
def quantize_to_qtype(x, q_dtype, compute_dtype, scale):
|
||||
# Explicitly cast the max values to the compute dtype to avoid unnecessary
|
||||
# casting to FP32 during the subsequent math operations.
|
||||
assert q_dtype in (jnp.float8_e4m3fn, )
|
||||
dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
|
||||
scaled_x = x / jnp.broadcast_to(
|
||||
jnp.asarray(scale, dtype=compute_dtype), x.shape
|
||||
)
|
||||
clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
|
||||
return clipped_x.astype(q_dtype)
|
||||
|
||||
def quantize_dequantize(x, q_dtype, scale, compute_dtype):
|
||||
qx = quantize_to_qtype(x, q_dtype, compute_dtype, scale)
|
||||
out = qx.astype(x.dtype) * jnp.broadcast_to(
|
||||
jnp.asarray(scale, dtype=x.dtype), qx.shape
|
||||
)
|
||||
return out
|
||||
|
||||
def generate_quantized_tensors(
|
||||
batch, lhs_non_contract, contract, rhs_non_contract,
|
||||
configs, dtype=jnp.float32,
|
||||
):
|
||||
cast_to_representable = partial(
|
||||
quantize_dequantize,
|
||||
scale=jnp.ones((1,)),
|
||||
compute_dtype=dtype,
|
||||
)
|
||||
|
||||
k1, k2 = jax.random.split(jax.random.key(123), 2)
|
||||
|
||||
a = cast_to_representable(
|
||||
jax.random.uniform(
|
||||
k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
|
||||
),
|
||||
configs[0].data_type,
|
||||
)
|
||||
b = cast_to_representable(
|
||||
jax.random.uniform(
|
||||
k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
|
||||
),
|
||||
configs[1].data_type,
|
||||
)
|
||||
|
||||
dn = ((2,), (0,))
|
||||
a_3d = shape_normalization(a, dn)
|
||||
b_3d = shape_normalization(b, dn)
|
||||
a_q, a_scales = quantize(a, configs[0])
|
||||
b_q, b_scales = quantize(b, configs[1])
|
||||
|
||||
return a, b, a_q, b_q, a_scales, b_scales
|
||||
|
||||
|
||||
def shard_and_device_put(
|
||||
mesh, a_sharding, b_sharding, a, b, a_scales=None, b_scales=None
|
||||
):
|
||||
a_spec = PartitionSpec(*a_sharding)
|
||||
b_spec = PartitionSpec(*b_sharding)
|
||||
|
||||
a_named_sharding = NamedSharding(mesh, a_spec)
|
||||
b_named_sharding = NamedSharding(mesh, b_spec)
|
||||
|
||||
a = jax.device_put(a, a_named_sharding)
|
||||
b = jax.device_put(b, b_named_sharding)
|
||||
if a_scales is not None:
|
||||
a_scales = jax.device_put(a_scales, a_named_sharding)
|
||||
if b_scales is not None:
|
||||
b_scales = jax.device_put(b_scales, b_named_sharding)
|
||||
|
||||
in_shardings = (
|
||||
a_named_sharding,
|
||||
b_named_sharding,
|
||||
)
|
||||
if a_scales is not None and b_scales is not None:
|
||||
in_shardings = (
|
||||
a_named_sharding,
|
||||
b_named_sharding,
|
||||
a_named_sharding,
|
||||
b_named_sharding,
|
||||
)
|
||||
return a, b, a_scales, b_scales, in_shardings
|
||||
|
||||
return a, b, in_shardings
|
||||
|
||||
def create_mxfp8_configs():
|
||||
if _dtypes.float8_e8m0fnu is None:
|
||||
return None
|
||||
|
||||
mxfp8_config = BlockScaleConfig(
|
||||
mode='mxfp8',
|
||||
block_size=32,
|
||||
data_type=jnp.float8_e4m3fn,
|
||||
scale_type=jnp.float8_e8m0fnu,
|
||||
global_scale=None,
|
||||
infer_only=False
|
||||
)
|
||||
|
||||
return [mxfp8_config for _ in range(3)]
|
||||
|
||||
def get_hlo_text(in_shardings, block_scale_configs):
|
||||
mesh_names = ("dp", "tp")
|
||||
devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
|
||||
mesh = Mesh(devices, mesh_names)
|
||||
_, _, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
|
||||
2, 512, 1024, 512, block_scale_configs,
|
||||
)
|
||||
|
||||
with mesh:
|
||||
a_q, b_q, a_scales, b_scales, in_shardings = shard_and_device_put(
|
||||
mesh, in_shardings[0], in_shardings[1], a_q, b_q, a_scales, b_scales
|
||||
)
|
||||
pjit_fn = jax.jit(scaled_matmul_wrapper, in_shardings=in_shardings)
|
||||
hlo = pjit_fn.lower(a_q, b_q, a_scales, b_scales).compile()
|
||||
return hlo.as_text()
|
||||
|
||||
|
||||
@jtu.with_config(jax_numpy_dtype_promotion="standard")
|
||||
class ScaledMatmulTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
try:
|
||||
cudnn_version = check_cudnn_version()
|
||||
except RuntimeError as e:
|
||||
self.skipTest(str(e))
|
||||
return
|
||||
if _dtypes.float8_e8m0fnu is None:
|
||||
self.skipTest("Requries >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
|
||||
if cudnn_version < 90700:
|
||||
self.skipTest("Requires >= cuDNN 9.7.0")
|
||||
if not jtu.is_cuda_compute_capability_at_least("10.0"):
|
||||
self.skipTest("Requires at least Blackwell arch")
|
||||
|
||||
mxfp8_configs = create_mxfp8_configs()
|
||||
|
||||
@jtu.sample_product(
|
||||
in_shardings=sharding_configs,
|
||||
block_scale_configs=[mxfp8_configs,],
|
||||
)
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_collectives(self, in_shardings, block_scale_configs):
|
||||
if jtu.device_under_test() != "gpu" or len(jax.local_devices()) < 4:
|
||||
self.skipTest("Partition Test enabled for at least 4 GPUs")
|
||||
|
||||
expected_hlo = sharding_configs[in_shardings][0]
|
||||
hlo_text = get_hlo_text(in_shardings, block_scale_configs)
|
||||
|
||||
hlo_pattern = re.compile(
|
||||
r".*".join([re.escape(x) for x in expected_hlo]), flags=re.DOTALL
|
||||
)
|
||||
self.assertRegex(
|
||||
hlo_text, hlo_pattern, msg=f"Failed to find pattern: {expected_hlo}"
|
||||
)
|
||||
|
||||
@jtu.sample_product(
|
||||
contract=[160, 96],
|
||||
lhs_non_contract=[240, 100],
|
||||
dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
|
||||
block_scale_configs=[mxfp8_configs,],
|
||||
)
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_scaled_matmul(
|
||||
self, contract, lhs_non_contract, dtype, block_scale_configs,
|
||||
):
|
||||
batch, rhs_non_contract = 2, 128
|
||||
a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
|
||||
batch, lhs_non_contract, contract, rhs_non_contract,
|
||||
block_scale_configs, dtype=dtype,
|
||||
)
|
||||
|
||||
def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
|
||||
return scaled_matmul_wrapper(
|
||||
lhs,
|
||||
rhs,
|
||||
lhs_scales,
|
||||
rhs_scales,
|
||||
preferred_element_type=out_type,
|
||||
)
|
||||
|
||||
j_scaled_matmul = jax.jit(partial(wrapper, out_type=dtype))
|
||||
hlo_text = (
|
||||
j_scaled_matmul.lower(a_q, b_q, a_scales, b_scales)
|
||||
.compile()
|
||||
.as_text()
|
||||
)
|
||||
hlo_pattern = re.compile(
|
||||
r".*".join([re.escape(x) for x in ("custom-call", c_name)])
|
||||
)
|
||||
self.assertRegex(hlo_text, hlo_pattern)
|
||||
|
||||
out = j_scaled_matmul(a_q, b_q, a_scales, b_scales)
|
||||
out_ref = np.einsum(
|
||||
"BMK,BNK->BMN", a.astype(jnp.float32), b.astype(jnp.float32)
|
||||
)
|
||||
self.assertArraysAllClose(
|
||||
out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3
|
||||
)
|
||||
|
||||
@jtu.sample_product(
|
||||
in_shardings=sharding_configs,
|
||||
block_scale_configs=[mxfp8_configs,],
|
||||
)
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs):
|
||||
if len(jax.local_devices()) < 4:
|
||||
self.skipTest("Require at least 4 devices to run sharding tests.")
|
||||
batch, contract, non_contract = 2, 1024, 256
|
||||
a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
|
||||
batch, non_contract, contract, non_contract, block_scale_configs,
|
||||
)
|
||||
|
||||
devices = np.array(jax.local_devices()[:4])
|
||||
devices = devices.reshape((2, 2))
|
||||
expected_output_spec = sharding_configs[in_shardings][1]
|
||||
|
||||
with Mesh(devices, ("dp", "tp")) as mesh:
|
||||
a_q, b_q, a_scales, b_scales, input_shardings = (
|
||||
shard_and_device_put(
|
||||
mesh,
|
||||
in_shardings[0],
|
||||
in_shardings[1],
|
||||
a_q,
|
||||
b_q,
|
||||
a_scales,
|
||||
b_scales,
|
||||
)
|
||||
)
|
||||
|
||||
args = [a_q, b_q, a_scales, b_scales]
|
||||
j_scaled_matmul = jax.jit(
|
||||
scaled_matmul_wrapper, in_shardings=input_shardings
|
||||
)
|
||||
hlo_compiled = j_scaled_matmul.lower(*args).compile()
|
||||
hlo_pattern = re.compile(
|
||||
r".*".join([re.escape(x) for x in ("custom-call", c_name)])
|
||||
)
|
||||
self.assertRegex(hlo_compiled.as_text(), hlo_pattern)
|
||||
|
||||
j_ref = jax.jit(
|
||||
partial(
|
||||
jax.lax.dot_general,
|
||||
dimension_numbers=(([2], [2]), ([0], [0])),
|
||||
),
|
||||
in_shardings=input_shardings[:2],
|
||||
)
|
||||
|
||||
out = j_scaled_matmul(*args)
|
||||
out_ref = j_ref(a, b)
|
||||
expected_output_sharding = NamedSharding(
|
||||
mesh=mesh, spec=expected_output_spec
|
||||
)
|
||||
self.assertArraysAllClose(out, out_ref, rtol=1e-3, atol=1e-3)
|
||||
self.assertTrue(
|
||||
out.sharding.is_equivalent_to(expected_output_sharding, out.ndim)
|
||||
)
|
||||
|
||||
@jtu.with_config(jax_numpy_dtype_promotion="standard")
|
||||
class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
try:
|
||||
cudnn_version = check_cudnn_version()
|
||||
except RuntimeError as e:
|
||||
self.skipTest(str(e))
|
||||
return
|
||||
if _dtypes.float8_e8m0fnu is None:
|
||||
self.skipTest("Requries >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
|
||||
if cudnn_version < 90700:
|
||||
self.skipTest("Requires >= cuDNN 9.7.0")
|
||||
if not jtu.is_cuda_compute_capability_at_least("10.0"):
|
||||
self.skipTest("Requires at least Blackwell arch")
|
||||
|
||||
block_scale_configs = create_mxfp8_configs()
|
||||
|
||||
@jtu.sample_product(
|
||||
configs=[
|
||||
# a_shape, b_shape, dimension_numbers, is_training
|
||||
((1, 32), (2, 32), (([1], [1]), ([], [])), False),
|
||||
((30, 64), (100, 64), (([1], [1]), ([], [])), False),
|
||||
((192, 96), (160, 96), (([1], [1]), ([], [])), True),
|
||||
((64, 128, 4), (128, 128), (([1], [0]), ([], [])), True),
|
||||
((1, 128, 1024), (1, 1024, 128), (([2], [1]), ([0], [0])), True),
|
||||
(
|
||||
(1, 128, 128, 2),
|
||||
(128, 1, 2, 128),
|
||||
(([2], [0]), ([0, 3], [1, 2])),
|
||||
True,
|
||||
),
|
||||
],
|
||||
output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
|
||||
)
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_dot_general(self, configs, output_type):
|
||||
cast_to_representable = partial(
|
||||
quantize_dequantize,
|
||||
scale=jnp.ones((1,)),
|
||||
compute_dtype=jnp.float32,
|
||||
)
|
||||
k1, k2 = jax.random.split(jax.random.key(0), 2)
|
||||
|
||||
a_shape, b_shape, dimension_numbers, is_training = configs
|
||||
a = cast_to_representable(
|
||||
jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type),
|
||||
self.block_scale_configs[0].data_type,
|
||||
)
|
||||
b = cast_to_representable(
|
||||
jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type),
|
||||
self.block_scale_configs[1].data_type,
|
||||
)
|
||||
|
||||
scaled_dot_general = partial(
|
||||
scaled_dot_general_wrapper,
|
||||
configs=self.block_scale_configs
|
||||
)
|
||||
def fwd(a, b, is_ref=False):
|
||||
fn = jax.lax.dot_general if is_ref else scaled_dot_general
|
||||
y = fn(a, b, dimension_numbers,
|
||||
preferred_element_type=output_type)
|
||||
return jnp.sum(y)
|
||||
|
||||
if is_training:
|
||||
j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))
|
||||
|
||||
j_train_ref = jax.jit(
|
||||
jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
|
||||
)
|
||||
out, (x_grad, w_grad) = j_train(a, b)
|
||||
out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)
|
||||
|
||||
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
|
||||
self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
|
||||
self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
|
||||
else:
|
||||
j_inference = jax.jit(fwd)
|
||||
j_inference_ref = jax.jit(partial(fwd, is_ref=True))
|
||||
out = j_inference(a, b)
|
||||
out_ref = j_inference_ref(a, b)
|
||||
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
|
||||
|
||||
@jtu.sample_product(in_shardings=sharding_configs)
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_dot_general_sharded(self, in_shardings):
|
||||
if len(jax.local_devices()) < 4:
|
||||
self.skipTest("Require at least 4 devices to run sharding tests.")
|
||||
|
||||
cast_to_representable = partial(
|
||||
quantize_dequantize,
|
||||
scale=jnp.ones((1,)),
|
||||
compute_dtype=jnp.float32,
|
||||
)
|
||||
|
||||
dimension_numbers = (([2], [2]), ([0], [0]))
|
||||
a_shape = (2, 128, 512)
|
||||
b_shape = (2, 256, 512)
|
||||
|
||||
k1, k2 = jax.random.split(jax.random.key(0), 2)
|
||||
a = cast_to_representable(
|
||||
jax.random.uniform(k1, a_shape, minval=-1.0),
|
||||
self.block_scale_configs[0].data_type,
|
||||
)
|
||||
b = cast_to_representable(
|
||||
jax.random.uniform(k2, b_shape, minval=-1.0),
|
||||
self.block_scale_configs[1].data_type,
|
||||
)
|
||||
|
||||
scaled_dot_general = partial(
|
||||
scaled_dot_general_wrapper,
|
||||
configs=self.block_scale_configs
|
||||
)
|
||||
def fwd(a, b, is_ref=False):
|
||||
fn = jax.lax.dot_general if is_ref else scaled_dot_general
|
||||
y = fn(a, b, dimension_numbers)
|
||||
# Use a little complex loss function to avoid constant grads, whose
|
||||
# sharding info might be optimized off and then cause issue with the
|
||||
# custom scaled_matmul op.
|
||||
return jnp.sum(jnp.tanh(y))
|
||||
|
||||
devices = np.array(jax.local_devices()[:4])
|
||||
devices = devices.reshape((2, 2))
|
||||
with Mesh(devices, ("dp", "tp")) as mesh:
|
||||
a, b, input_shardings = (
|
||||
shard_and_device_put(
|
||||
mesh,
|
||||
in_shardings[0],
|
||||
in_shardings[1],
|
||||
a,
|
||||
b,
|
||||
)
|
||||
)
|
||||
|
||||
j_train = jax.jit(jax.value_and_grad(partial(fwd), argnums=[0, 1]),
|
||||
in_shardings=input_shardings)
|
||||
hlo_text = j_train.lower(a, b).compile().as_text()
|
||||
hlo_pattern = re.compile(
|
||||
r".*".join([re.escape(x) for x in ("custom-call", c_name)])
|
||||
)
|
||||
|
||||
j_train_ref = jax.jit(
|
||||
jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1]),
|
||||
in_shardings=input_shardings
|
||||
)
|
||||
out, (x_grad, w_grad) = j_train(a, b)
|
||||
out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)
|
||||
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
|
||||
self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
|
||||
self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
configs=[
|
||||
((1, 128, 256), (1, 128, 256), (0, 0, 0)),
|
||||
((2, 128, 128), (2, 128, 128), (0, 0, 0)),
|
||||
((2, 128, 128), (128, 2, 128), (0, 1, 2)),
|
||||
]
|
||||
)
|
||||
@jtu.run_on_devices("cuda")
|
||||
def test_dot_general_vmap(self, configs):
|
||||
cast_to_representable = partial(
|
||||
quantize_dequantize,
|
||||
scale=jnp.ones((1,)),
|
||||
compute_dtype=jnp.float32,
|
||||
)
|
||||
k1, k2 = jax.random.split(jax.random.key(0), 2)
|
||||
|
||||
a_shape, b_shape, vmap_axes = configs
|
||||
a_axis, b_axis, o_axis = vmap_axes
|
||||
dimension_numbers = (([1], [1]), ([], []))
|
||||
|
||||
a = cast_to_representable(
|
||||
jax.random.uniform(k1, a_shape, minval=-1.0),
|
||||
self.block_scale_configs[0].data_type,
|
||||
)
|
||||
b = cast_to_representable(
|
||||
jax.random.uniform(k2, b_shape, minval=-1.0),
|
||||
self.block_scale_configs[1].data_type,
|
||||
)
|
||||
|
||||
scaled_dot_general = partial(
|
||||
scaled_dot_general_wrapper,
|
||||
configs=self.block_scale_configs
|
||||
)
|
||||
def fwd(a, b, is_ref=False):
|
||||
fn = jax.vmap(
|
||||
jax.lax.dot_general if is_ref else scaled_dot_general,
|
||||
in_axes=(a_axis, b_axis, None),
|
||||
out_axes=o_axis,
|
||||
)
|
||||
y = fn(a, b, dimension_numbers)
|
||||
return jnp.sum(y)
|
||||
|
||||
j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))
|
||||
j_train_ref = jax.jit(
|
||||
jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
|
||||
)
|
||||
out, (x_grad, w_grad) = j_train(a, b)
|
||||
out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)
|
||||
|
||||
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e2)
|
||||
self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
|
||||
self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|