mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[Mosaic TPU] Support bf16 div if HW does not directly support.
PiperOrigin-RevId: 726212286
This commit is contained in:
parent
153a7cf913
commit
876668faa1
@ -677,18 +677,18 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
|
||||
return *rules;
|
||||
}
|
||||
|
||||
const llvm::StringSet<> &elementwise_convertible_ops() {
|
||||
static auto ops = new llvm::StringSet<>{arith::MulFOp::getOperationName(),
|
||||
arith::DivFOp::getOperationName(),
|
||||
arith::AddFOp::getOperationName(),
|
||||
arith::SubFOp::getOperationName(),
|
||||
arith::MaximumFOp::getOperationName(),
|
||||
arith::MinimumFOp::getOperationName(),
|
||||
math::PowFOp::getOperationName(),
|
||||
math::TanhOp::getOperationName(),
|
||||
math::ExpOp::getOperationName(),
|
||||
math::LogOp::getOperationName()};
|
||||
return *ops;
|
||||
// Reports whether `op` should be handed to the elementwise canonicalization.
//
// arith.divf is special-cased: it is skipped (returns false) only when its
// first operand is a vector with bf16 elements and ctx.hardware_generation is
// at least 4 -- presumably because that hardware divides bf16 directly
// (NOTE(review): confirm). Every other divf returns true. For all remaining
// ops the result is simply membership in the list of arith/math float ops
// below.
bool need_elementwise_canonicalization(CanonicalizeContext ctx, Operation &op) {
  if (!isa<arith::DivFOp>(op)) {
    return isa<arith::MulFOp, arith::AddFOp, arith::SubFOp, arith::MaximumFOp,
               arith::MinimumFOp, math::PowFOp, math::TanhOp, math::ExpOp,
               math::LogOp>(op);
  }
  // dyn_cast yields a null type for non-vector operands, so the element type
  // is only inspected after the null check.
  const auto lhs_vec_ty = dyn_cast<VectorType>(op.getOperand(0).getType());
  const bool is_bf16_vector =
      lhs_vec_ty && lhs_vec_ty.getElementType().isBF16();
  const bool skip_canonicalization =
      is_bf16_vector && ctx.hardware_generation >= 4;
  return !skip_canonicalization;
}
|
||||
|
||||
class MosaicCanonicalizer {
|
||||
@ -730,8 +730,7 @@ class MosaicCanonicalizer {
|
||||
}
|
||||
}
|
||||
}
|
||||
if (elementwise_convertible_ops().contains(
|
||||
any_op.getName().getStringRef())) {
|
||||
if (need_elementwise_canonicalization(ctx, any_op)) {
|
||||
return canonicalize_elementwise(ctx, any_op);
|
||||
}
|
||||
if (auto rule_it = rules().find(any_op.getName().getStringRef());
|
||||
|
@ -417,6 +417,27 @@ class OpsTest(PallasBaseTest):
|
||||
expected = np.take_along_axis(x, idx, axis=axis)
|
||||
np.testing.assert_array_equal(actual, expected)
|
||||
|
||||
# Elementwise float division inside a Pallas kernel must agree with
# `jax.lax.div` applied to the same (8, 128) operands, for float32 and
# bfloat16. Skipped when libtpu is too old or the device is older than TPU v4.
@parameterized.product(dtype=[jnp.float32, jnp.bfloat16])
def test_float_div(self, dtype):
  if not jtu.if_cloud_tpu_at_least(2025, 2, 13):
    self.skipTest("Requires libtpu built after 2025-02-13")
  if not jtu.is_device_tpu_at_least(version=4):
    self.skipTest("Requires TPUv4+")

  # Only TPU version 6 gets a relaxed relative tolerance; every other version
  # falls back to assert_allclose's defaults.
  tolerance = {"rtol": 1e-2} if jtu.get_tpu_version() == 6 else {}

  def kernel(x, y, out):
    out[:] = jax.lax.div(x[:], y[:])

  shape = (8, 128)
  divide = pl.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct(shape, dtype),
  )

  # Fixed seed so both operands are reproducible across runs.
  lhs_key, rhs_key = jax.random.split(jax.random.key(1234), 2)
  lhs = jax.random.normal(lhs_key, shape, dtype=dtype)
  rhs = jax.random.normal(rhs_key, shape, dtype=dtype)
  np.testing.assert_allclose(
      divide(lhs, rhs), jax.lax.div(lhs, rhs), **tolerance
  )
|
||||
|
||||
|
||||
# Inherits every test from OpsTest and flips the INTERPRET class flag on.
# NOTE(review): presumably this makes the suite run in Pallas interpret mode --
# confirm against how the base test class consumes INTERPRET.
class OpsInterpretTest(OpsTest):
  INTERPRET = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user