Mirror of https://github.com/ROCm/jax.git
Remove the float16 blocking for dot on the CPU platform so that CPUs with native float16 matmul support can use it for better performance. With this change, JAX allows float16 dot HLO to be created directly. When the HLO modules are processed during the CPU compile stage in OpenXLA, the ChangeOpDataType pass upcasts the dot to float32 if the CPU platform does not support float16 computation; on platforms that do support float16, the dot stays in float16 for execution.
This commit is contained in:
parent 4a7b293bd9
commit b37f51487d
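As a rough illustration of the effect described above (not part of the commit itself), the HLO that JAX emits for a float16 matmul can be inspected with the standard jax.jit(...).lower(...) and as_text() APIs; the exact HLO text varies by JAX/XLA version:

import jax
import jax.numpy as jnp

# Two float16 operands; with this change the CPU lowering no longer
# upcasts them to float32 before emitting the dot.
x = jnp.ones((128, 256), dtype=jnp.float16)
y = jnp.ones((256, 64), dtype=jnp.float16)

lowered = jax.jit(jnp.dot).lower(x, y)
# The emitted (pre-optimization) HLO now contains an f16 dot on CPU.
print(lowered.as_text())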
@@ -2959,16 +2959,6 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                              core.ShapedArray(rhs_aval.shape, aval_out.dtype))
       lhs_dtype = rhs_dtype = aval_out.dtype
 
-  # TODO(b/195364460): Work around slow XLA/CPU implementation of float16 matmul
-  if platform == "cpu":
-    if lhs_dtype == np.float16:
-      lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
-                             core.ShapedArray(lhs_aval.shape, np.float32))
-
-    if rhs_dtype == np.float16:
-      rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
-                             core.ShapedArray(rhs_aval.shape, np.float32))
-
 
   dot_dnums = hlo.DotDimensionNumbers.get(
       lhs_batching_dimensions=list(lhs_batch),
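A hedged sketch of how one might check what the CPU backend ultimately does with the f16 dot after this removal: the ChangeOpDataType pass mentioned in the commit message runs inside XLA's CPU pipeline, so its effect shows up only in the compiled HLO. Output text varies by XLA version.

import jax
import jax.numpy as jnp

x = jnp.ones((32, 32), dtype=jnp.float16)
lowered = jax.jit(lambda a, b: a @ b).lower(x, x)

# HLO as emitted by JAX: the dot keeps f16 operands after this change.
print(lowered.as_text())

# HLO after XLA's CPU optimization pipeline: the dot stays f16 on CPUs
# with native float16 matmul support, otherwise it is upcast to f32.
print(lowered.compile().as_text())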