Remove the float16 dot blocking on the CPU platform to take advantage of CPU
platforms that support float16 matmul computation, for better performance.

With this change, JAX allows float16 dot HLO to be created. When the HLO
modules are processed during the CPU compile stage in OpenXLA, the
ChangeOpDataType pass upcasts the dot to float32 if the CPU platform does not
support float16 computation; on platforms that do support float16, the dot
stays in float16 for execution.
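
As a hypothetical illustration (not part of this commit), the effect can be
observed by lowering a float16 dot on CPU and inspecting the emitted IR; after
this change the dot_general keeps float16 operands instead of being wrapped in
float32 converts:

    # Hypothetical sketch: inspect the IR that JAX lowers for a float16 dot on CPU.
    import numpy as np
    import jax
    import jax.numpy as jnp

    x = jnp.ones((8, 8), dtype=np.float16)
    y = jnp.ones((8, 8), dtype=np.float16)

    # jax.jit(...).lower(...) returns the module that XLA:CPU will compile.
    lowered = jax.jit(jnp.dot).lower(x, y)
    # With this change the printed dot_general takes f16 operands directly;
    # previously they were converted to f32 before the dot on CPU.
    print(lowered.as_text())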
Yimei Sun 2024-06-23 23:51:30 -07:00
parent 4a7b293bd9
commit b37f51487d


@@ -2959,16 +2959,6 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                                core.ShapedArray(rhs_aval.shape, aval_out.dtype))
         lhs_dtype = rhs_dtype = aval_out.dtype
-  # TODO(b/195364460): Work around slow XLA/CPU implementation of float16 matmul
-  if platform == "cpu":
-    if lhs_dtype == np.float16:
-      lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
-                             core.ShapedArray(lhs_aval.shape, np.float32))
-    if rhs_dtype == np.float16:
-      rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
-                             core.ShapedArray(rhs_aval.shape, np.float32))
   dot_dnums = hlo.DotDimensionNumbers.get(
       lhs_batching_dimensions=list(lhs_batch),