[jax] Improve naming of DotAlgorithmPreset properties and simplify return types.

PiperOrigin-RevId: 702317395
Chris Jones 2024-12-03 06:25:55 -08:00 committed by jax authors
parent 0bb68f6ad2
commit abf8f43007
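For context, a minimal sketch (not part of the commit) of how the renamed preset properties behave after this change, based on the cases visible in the diff below; the import paths are assumed to be the usual public JAX locations:

    import numpy as np
    import jax.numpy as jnp  # jnp.bfloat16 mirrors the internal dtypes.bfloat16
    from jax.lax import DotAlgorithmPreset

    # Old: lhs_precision_type could return a bare dtype such as np.float16.
    # New: supported_lhs_types always returns a tuple of dtypes, or None.
    assert DotAlgorithmPreset.F16_F16_F16.supported_lhs_types == (np.float16,)
    assert DotAlgorithmPreset.F64_F64_F64.supported_lhs_types == (np.float64,)
    assert (DotAlgorithmPreset.BF16_BF16_F32.supported_lhs_types
            == (jnp.bfloat16, np.float32))
    # supported_rhs_types simply forwards to supported_lhs_types.
    assert (DotAlgorithmPreset.F16_F16_F16.supported_rhs_types
            == DotAlgorithmPreset.F16_F16_F16.supported_lhs_types)
    # Presets that accept any input type report None.
    assert DotAlgorithmPreset.DEFAULT.supported_lhs_types is None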


@@ -876,7 +876,7 @@ class DotAlgorithmPreset(enum.Enum):
     return self.name
 
   @property
-  def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
+  def supported_lhs_types(self) -> tuple[DTypeLike, ...] | None:
     match self:
       case (
           DotAlgorithmPreset.DEFAULT
@@ -887,7 +887,7 @@ class DotAlgorithmPreset(enum.Enum):
       ):
         return None
       case DotAlgorithmPreset.F16_F16_F16 | DotAlgorithmPreset.F16_F16_F32:
-        return np.float16
+        return (np.float16,)
       case (
           DotAlgorithmPreset.BF16_BF16_BF16 |
           DotAlgorithmPreset.BF16_BF16_F32
@@ -897,13 +897,13 @@ class DotAlgorithmPreset(enum.Enum):
         # type. If not, we explicitly cast to bfloat16.
         return (dtypes.bfloat16, np.float32)
       case DotAlgorithmPreset.F64_F64_F64:
-        return np.float64
+        return (np.float64,)
       case _:
-        return np.float32
+        return (np.float32,)
 
   @property
-  def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
-    return self.lhs_precision_type
+  def supported_rhs_types(self) -> tuple[DTypeLike, ...] | None:
+    return self.supported_lhs_types
 
   @property
   def accumulation_type(self) -> DTypeLike | None:
@@ -927,12 +927,19 @@ class DotAlgorithmPreset(enum.Enum):
   def supported_output_types(self) -> tuple[DTypeLike, ...] | None:
     match self:
       case (
-          DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
-          DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
+          DotAlgorithmPreset.ANY_F8_ANY_F8_F32
+          | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
       ):
-        return (np.float32, np.float16, dtypes.bfloat16, dtypes.float8_e4m3fn,
-                dtypes.float8_e5m2, dtypes.float8_e5m2fnuz,
-                dtypes.float8_e4m3fnuz, dtypes.float8_e4m3b11fnuz)
+        return (
+            np.float32,
+            np.float16,
+            dtypes.bfloat16,
+            dtypes.float8_e4m3fn,
+            dtypes.float8_e5m2,
+            dtypes.float8_e5m2fnuz,
+            dtypes.float8_e4m3fnuz,
+            dtypes.float8_e4m3b11fnuz,
+        )
       case DotAlgorithmPreset.F16_F16_F32:
         return (np.float32, np.float16)
       case _:
@@ -3699,35 +3706,32 @@ def get_algorithm_compute_types(
     rhs_dtype: DTypeLike,
     out_dtype: DTypeLike | None = None,
 ) -> tuple[DTypeLike | None, DTypeLike | None, DTypeLike | None]:
-  def maybe_convert_dtype(input_dtype, target_dtype):
-    if target_dtype is None:
-      return input_dtype
-    if not isinstance(target_dtype, tuple):
-      target_dtype = (target_dtype,)
-    if np.dtype(input_dtype) in map(np.dtype, target_dtype):
-      return input_dtype
-    return target_dtype[0]
+  if isinstance(algorithm, DotAlgorithm):
+    return (
+        algorithm.lhs_precision_type,
+        algorithm.rhs_precision_type,
+        algorithm.accumulation_type,
+    )
 
+  supported_output_types = algorithm.supported_output_types
   if algorithm == DotAlgorithmPreset.BF16_BF16_F32:
-    lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type)
-    rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type)
-    if np.dtype(lhs_dtype) == dtypes.bfloat16:
-      out_dtype = maybe_convert_dtype(out_dtype,
-                                      (np.float32, dtypes.bfloat16))
-    else:
-      out_dtype = maybe_convert_dtype(out_dtype, np.float32)
-    return lhs_dtype, rhs_dtype, out_dtype
-  else:
-    if isinstance(algorithm, DotAlgorithmPreset):
-      supported_output_types = algorithm.supported_output_types
-    else:
-      supported_output_types = (algorithm.accumulation_type,)
-    return (
-        maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type),
-        maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type),
-        maybe_convert_dtype(out_dtype, supported_output_types),
-    )
+    # If dtype is anything other than float32, it will be cast to bfloat16.
+    if np.dtype(lhs_dtype) != np.float32:
+      supported_output_types = (np.float32, dtypes.bfloat16)
+
+  def maybe_convert_dtype(input_dtype, target_dtypes):
+    if target_dtypes is None:
+      return input_dtype
+    if np.dtype(input_dtype) in map(np.dtype, target_dtypes):
+      return input_dtype
+    return target_dtypes[0]
+
+  return (
+      maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types),
+      maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types),
+      maybe_convert_dtype(out_dtype, supported_output_types),
+  )
 
 
 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
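As a rough usage note (an illustration, not part of the commit): the reworked maybe_convert_dtype helper above keeps the requested dtype when it is already in the supported tuple and otherwise falls back to the first supported entry. A self-contained sketch of that rule, with hypothetical names:

    import numpy as np

    # Same selection rule as the new nested helper above, restated standalone.
    def pick_supported_dtype(requested, supported):
      if supported is None:      # no constraint: keep the requested dtype
        return requested
      if np.dtype(requested) in map(np.dtype, supported):
        return requested         # already supported: keep it
      return supported[0]        # otherwise fall back to the first option

    # F16_F16_F32 advertises (np.float32, np.float16) outputs in the diff above,
    # so a float16 request is kept while a float64 request falls back to float32.
    assert pick_supported_dtype(np.float16, (np.float32, np.float16)) == np.float16
    assert pick_supported_dtype(np.float64, (np.float32, np.float16)) == np.float32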