[Mosaic TPU] Support bf16 div if HW does not directly support.

PiperOrigin-RevId: 726212286
This commit is contained in:
Jevin Jiang 2025-02-12 15:03:16 -08:00 committed by jax authors
parent 153a7cf913
commit 876668faa1
2 changed files with 34 additions and 14 deletions

View File

@@ -677,18 +677,18 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
return *rules;
}
const llvm::StringSet<> &elementwise_convertible_ops() {
static auto ops = new llvm::StringSet<>{arith::MulFOp::getOperationName(),
arith::DivFOp::getOperationName(),
arith::AddFOp::getOperationName(),
arith::SubFOp::getOperationName(),
arith::MaximumFOp::getOperationName(),
arith::MinimumFOp::getOperationName(),
math::PowFOp::getOperationName(),
math::TanhOp::getOperationName(),
math::ExpOp::getOperationName(),
math::LogOp::getOperationName()};
return *ops;
// Returns true iff `op` must be rewritten by the elementwise canonicalization
// pass before lowering. Most listed elementwise float ops are unconditionally
// canonicalized; float division is special-cased below.
bool need_elementwise_canonicalization(CanonicalizeContext ctx, Operation &op) {
  if (!isa<arith::DivFOp>(op)) {
    // Every other op in this set is canonicalized whenever it appears.
    return isa<arith::MulFOp, arith::AddFOp, arith::SubFOp, arith::MaximumFOp,
               arith::MinimumFOp, math::PowFOp, math::TanhOp, math::ExpOp,
               math::LogOp>(op);
  }
  // arith.divf: skip canonicalization only for bf16 vectors on hardware
  // generation >= 4 — presumably those generations handle bf16 division
  // directly (TODO confirm against the lowering rules); everything else
  // still goes through the canonicalization path.
  auto vec_ty = dyn_cast<VectorType>(op.getOperand(0).getType());
  bool bf16_div_supported = vec_ty && vec_ty.getElementType().isBF16() &&
                            ctx.hardware_generation >= 4;
  return !bf16_div_supported;
}
class MosaicCanonicalizer {
@@ -730,8 +730,7 @@ class MosaicCanonicalizer {
}
}
}
if (elementwise_convertible_ops().contains(
any_op.getName().getStringRef())) {
if (need_elementwise_canonicalization(ctx, any_op)) {
return canonicalize_elementwise(ctx, any_op);
}
if (auto rule_it = rules().find(any_op.getName().getStringRef());

View File

@@ -417,6 +417,27 @@ class OpsTest(PallasBaseTest):
expected = np.take_along_axis(x, idx, axis=axis)
np.testing.assert_array_equal(actual, expected)
@parameterized.product(dtype=[jnp.float32, jnp.bfloat16])
def test_float_div(self, dtype):
  """Checks pl.pallas_call division against jax.lax.div for f32 and bf16."""
  # Needs a libtpu recent enough to include bf16-div support in Mosaic.
  if not jtu.if_cloud_tpu_at_least(2025, 2, 13):
    self.skipTest("Requires libtpu built after 2025-02-13")
  if not jtu.is_device_tpu_at_least(version=4):
    self.skipTest("Requires TPUv4+")

  # TPUv6 gets a looser relative tolerance; all other generations are
  # compared with assert_allclose defaults.
  tolerance = {}
  if jtu.get_tpu_version() == 6:
    tolerance["rtol"] = 1e-2

  def div_kernel(x_ref, y_ref, o_ref):
    o_ref[:] = jax.lax.div(x_ref[:], y_ref[:])

  divide = pl.pallas_call(
      div_kernel,
      out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
  )

  key_x, key_y = jax.random.split(jax.random.key(1234), 2)
  x = jax.random.normal(key_x, (8, 128), dtype=dtype)
  y = jax.random.normal(key_y, (8, 128), dtype=dtype)
  np.testing.assert_allclose(divide(x, y), jax.lax.div(x, y), **tolerance)
class OpsInterpretTest(OpsTest):
  # Re-runs the entire OpsTest suite with INTERPRET = True — presumably the
  # base class threads this flag into pl.pallas_call(interpret=...) so the
  # kernels execute in interpret mode; confirm against PallasBaseTest.
  INTERPRET = True