Temporarily switch triton.compat to use Triton APIs for math and semantic operations

This is only meant as a short-term fix to unblock internal users.

PiperOrigin-RevId: 602707085
This commit is contained in:
Sergei Lebedev 2024-01-30 06:29:24 -08:00 committed by jax authors
parent 66308c30ad
commit 9e76e380cc

View File

@ -1287,6 +1287,11 @@ class math:
(float64,): ("__nv_rsqrt", float64),
})
# TODO(slebedev): Fix the implementation above and remove this.
for name in vars(math):
if not name.startswith("__") and hasattr(tl.math, name):
setattr(math, name, wrap_with_builder(getattr(tl.math, name)))
class semantic:
cast = wrap_with_builder(tl.semantic.cast)
@ -1560,3 +1565,8 @@ class semantic:
_bool_block_like(x),
)
raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
# TODO(slebedev): Fix the implementation above and remove this.
for name in vars(semantic):
if not name.startswith("__") and hasattr(tl.semantic, name):
setattr(semantic, name, wrap_with_builder(getattr(tl.semantic, name)))