From 876668faa1fc7982fb53a0e4aa00a1726f93c196 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 12 Feb 2025 15:03:16 -0800 Subject: [PATCH] [Mosaic TPU] Support bf16 div if HW does not directly support. PiperOrigin-RevId: 726212286 --- .../tpu/transforms/canonicalize_mosaic.cc | 27 +++++++++---------- tests/pallas/tpu_ops_test.py | 21 +++++++++++++++ 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 098a69cc0..5efbdb9cb 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -677,18 +677,18 @@ const llvm::StringMap &rules() { return *rules; } -const llvm::StringSet<> &elementwise_convertible_ops() { - static auto ops = new llvm::StringSet<>{arith::MulFOp::getOperationName(), - arith::DivFOp::getOperationName(), - arith::AddFOp::getOperationName(), - arith::SubFOp::getOperationName(), - arith::MaximumFOp::getOperationName(), - arith::MinimumFOp::getOperationName(), - math::PowFOp::getOperationName(), - math::TanhOp::getOperationName(), - math::ExpOp::getOperationName(), - math::LogOp::getOperationName()}; - return *ops; +bool need_elementwise_canonicalization(CanonicalizeContext ctx, Operation &op) { + if (isa(op)) { + auto vec_ty = dyn_cast(op.getOperand(0).getType()); + if (vec_ty && vec_ty.getElementType().isBF16() && + ctx.hardware_generation >= 4) { + return false; + } + return true; + } + return isa(op); } class MosaicCanonicalizer { @@ -730,8 +730,7 @@ class MosaicCanonicalizer { } } } - if (elementwise_convertible_ops().contains( - any_op.getName().getStringRef())) { + if (need_elementwise_canonicalization(ctx, any_op)) { return canonicalize_elementwise(ctx, any_op); } if (auto rule_it = rules().find(any_op.getName().getStringRef()); diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 73170aa7a..5c4e60726 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -417,6 +417,27 @@ class OpsTest(PallasBaseTest): expected = np.take_along_axis(x, idx, axis=axis) np.testing.assert_array_equal(actual, expected) + @parameterized.product(dtype=[jnp.float32, jnp.bfloat16]) + def test_float_div(self, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 2, 13): + self.skipTest("Requires libtpu built after 2025-02-13") + if not jtu.is_device_tpu_at_least(version=4): + self.skipTest("Requires TPUv4+") + kwargs = {} + if jtu.get_tpu_version() == 6: + kwargs.update(dict(rtol=1e-2)) + def kernel(x, y, out): + out[:] = jax.lax.div(x[:], y[:]) + + run = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128), dtype), + ) + k1, k2 = jax.random.split(jax.random.key(1234), 2) + x = jax.random.normal(k1, (8, 128), dtype=dtype) + y = jax.random.normal(k2, (8, 128), dtype=dtype) + np.testing.assert_allclose(run(x, y), jax.lax.div(x, y), **kwargs) + class OpsInterpretTest(OpsTest): INTERPRET = True