mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Temporarily switch triton.compat to use Triton APIs for math and semantic operations
This is only meant as a short-term fix to unblock internal users. PiperOrigin-RevId: 602707085
This commit is contained in:
parent
66308c30ad
commit
9e76e380cc
@ -1287,6 +1287,11 @@ class math:
|
||||
(float64,): ("__nv_rsqrt", float64),
|
||||
})
|
||||
|
||||
# TODO(slebedev): Fix the implementation above and remove this.
|
||||
for name in vars(math):
|
||||
if not name.startswith("__") and hasattr(tl.math, name):
|
||||
setattr(math, name, wrap_with_builder(getattr(tl.math, name)))
|
||||
|
||||
|
||||
class semantic:
|
||||
cast = wrap_with_builder(tl.semantic.cast)
|
||||
@ -1560,3 +1565,8 @@ class semantic:
|
||||
_bool_block_like(x),
|
||||
)
|
||||
raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
|
||||
|
||||
# TODO(slebedev): Fix the implementation above and remove this.
|
||||
for name in vars(semantic):
|
||||
if not name.startswith("__") and hasattr(tl.semantic, name):
|
||||
setattr(semantic, name, wrap_with_builder(getattr(tl.semantic, name)))
|
||||
|
Loading…
x
Reference in New Issue
Block a user