Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
[jax] Improve naming of DotAlgorithmPreset properties and simplify return types.
PiperOrigin-RevId: 702317395
This commit is contained in:
parent 0bb68f6ad2
commit abf8f43007
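In short: the preset properties `lhs_precision_type` and `rhs_precision_type` become `supported_lhs_types` and `supported_rhs_types`, and every branch now returns a tuple of dtypes (or None) instead of sometimes returning a bare dtype. A minimal sketch of what this means for callers, assuming a JAX build that includes this commit (`jax.lax.DotAlgorithmPreset` is the public alias of the enum changed below):

import numpy as np
import jax.lax as lax

# Previously, lhs_precision_type could yield a bare dtype (np.float16), a
# tuple, or None depending on the preset. The renamed property always yields
# a tuple of dtypes, or None when the preset imposes no constraint.
assert lax.DotAlgorithmPreset.F16_F16_F32.supported_lhs_types == (np.float16,)
assert lax.DotAlgorithmPreset.DEFAULT.supported_lhs_types is None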
@@ -876,7 +876,7 @@ class DotAlgorithmPreset(enum.Enum):
     return self.name

   @property
-  def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
+  def supported_lhs_types(self) -> tuple[DTypeLike, ...] | None:
     match self:
       case (
           DotAlgorithmPreset.DEFAULT
@@ -887,7 +887,7 @@ class DotAlgorithmPreset(enum.Enum):
       ):
         return None
       case DotAlgorithmPreset.F16_F16_F16 | DotAlgorithmPreset.F16_F16_F32:
-        return np.float16
+        return (np.float16,)
       case (
           DotAlgorithmPreset.BF16_BF16_BF16 |
           DotAlgorithmPreset.BF16_BF16_F32
@@ -897,13 +897,13 @@ class DotAlgorithmPreset(enum.Enum):
         # type. If not, we explicitly cast to bfloat16.
         return (dtypes.bfloat16, np.float32)
       case DotAlgorithmPreset.F64_F64_F64:
-        return np.float64
+        return (np.float64,)
       case _:
-        return np.float32
+        return (np.float32,)

   @property
-  def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
-    return self.lhs_precision_type
+  def supported_rhs_types(self) -> tuple[DTypeLike, ...] | None:
+    return self.supported_lhs_types

   @property
   def accumulation_type(self) -> DTypeLike | None:
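Because `supported_rhs_types` simply forwards to `supported_lhs_types`, and both now return a tuple or None, callers can drop the scalar-vs-tuple normalization the old contract forced on them. A hedged sketch of the calling pattern; `casts_lhs` is a hypothetical helper, not part of JAX:

import numpy as np
import jax.lax as lax

def casts_lhs(preset: lax.DotAlgorithmPreset, dtype) -> bool:
  """True if `dtype` would need a cast to satisfy `preset`'s lhs types."""
  supported = preset.supported_lhs_types
  if supported is None:
    return False  # the preset leaves the input dtype alone
  return np.dtype(dtype) not in map(np.dtype, supported)

print(casts_lhs(lax.DotAlgorithmPreset.F16_F16_F32, np.float32))  # True
print(casts_lhs(lax.DotAlgorithmPreset.F64_F64_F64, np.float64))  # False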
@@ -927,12 +927,19 @@ class DotAlgorithmPreset(enum.Enum):
   def supported_output_types(self) -> tuple[DTypeLike, ...] | None:
     match self:
       case (
-          DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
-          DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
+          DotAlgorithmPreset.ANY_F8_ANY_F8_F32
+          | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
       ):
-        return (np.float32, np.float16, dtypes.bfloat16, dtypes.float8_e4m3fn,
-                dtypes.float8_e5m2, dtypes.float8_e5m2fnuz,
-                dtypes.float8_e4m3fnuz, dtypes.float8_e4m3b11fnuz)
+        return (
+            np.float32,
+            np.float16,
+            dtypes.bfloat16,
+            dtypes.float8_e4m3fn,
+            dtypes.float8_e5m2,
+            dtypes.float8_e5m2fnuz,
+            dtypes.float8_e4m3fnuz,
+            dtypes.float8_e4m3b11fnuz,
+        )
       case DotAlgorithmPreset.F16_F16_F32:
         return (np.float32, np.float16)
       case _:
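The FP8 presets accept a wide set of output types, which is why the flattened tuple was worth splitting one entry per line. The property itself predates this hunk (only its body is reformatted), so it can be queried directly on a current build:

import jax.lax as lax

# float32/float16/bfloat16 plus the float8 variants, matching the tuple above.
for dt in lax.DotAlgorithmPreset.ANY_F8_ANY_F8_F32.supported_output_types:
  print(dt)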
@@ -3699,35 +3706,32 @@ def get_algorithm_compute_types(
     rhs_dtype: DTypeLike,
     out_dtype: DTypeLike | None = None,
 ) -> tuple[DTypeLike | None, DTypeLike | None, DTypeLike | None]:
-  def maybe_convert_dtype(input_dtype, target_dtype):
-    if target_dtype is None:
-      return input_dtype
-    if not isinstance(target_dtype, tuple):
-      target_dtype = (target_dtype,)
-    if np.dtype(input_dtype) in map(np.dtype, target_dtype):
-      return input_dtype
-    return target_dtype[0]
-
   if isinstance(algorithm, DotAlgorithm):
     return (
         algorithm.lhs_precision_type,
         algorithm.rhs_precision_type,
         algorithm.accumulation_type,
     )
+
+  supported_output_types = algorithm.supported_output_types
+
   if algorithm == DotAlgorithmPreset.BF16_BF16_F32:
-    lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type)
-    rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type)
-    if np.dtype(lhs_dtype) == dtypes.bfloat16:
-      out_dtype = maybe_convert_dtype(out_dtype,
-                                      (np.float32, dtypes.bfloat16))
-    else:
-      out_dtype = maybe_convert_dtype(out_dtype, np.float32)
-    return lhs_dtype, rhs_dtype, out_dtype
-  else:
-    if isinstance(algorithm, DotAlgorithmPreset):
-      supported_output_types = algorithm.supported_output_types
-    else:
-      supported_output_types = (algorithm.accumulation_type,)
-    return (
-        maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type),
-        maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type),
-        maybe_convert_dtype(out_dtype, supported_output_types),
-    )
+    # If dtype is anything other than float32, it will be cast to bfloat16.
+    if np.dtype(lhs_dtype) != np.float32:
+      supported_output_types = (np.float32, dtypes.bfloat16)
+
+  def maybe_convert_dtype(input_dtype, target_dtypes):
+    if target_dtypes is None:
+      return input_dtype
+    if np.dtype(input_dtype) in map(np.dtype, target_dtypes):
+      return input_dtype
+    return target_dtypes[0]
+
+  return (
+      maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types),
+      maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types),
+      maybe_convert_dtype(out_dtype, supported_output_types),
+  )


 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
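The net effect on `get_algorithm_compute_types` is that the BF16 special case now only tweaks the supported output set, and a single five-line rule handles all three dtypes. A standalone restatement of the new `maybe_convert_dtype` with a quick behavioural check (illustrative, not an import from JAX internals):

import numpy as np

def maybe_convert_dtype(input_dtype, target_dtypes):
  if target_dtypes is None:
    return input_dtype  # no constraint: keep the caller's dtype
  if np.dtype(input_dtype) in map(np.dtype, target_dtypes):
    return input_dtype  # already supported: no cast needed
  return target_dtypes[0]  # otherwise fall back to the first (preferred) entry

assert maybe_convert_dtype(np.float16, None) == np.float16
assert maybe_convert_dtype(np.float16, (np.float16,)) == np.float16
assert maybe_convert_dtype(np.float64, (np.float32,)) == np.float32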