1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

[pallas:triton] Fixed dispatch tablee for lax.pow_p

PiperOrigin-RevId: 722817510
This commit is contained in:
Sergei Lebedev 2025-02-03 15:17:16 -08:00 committed by jax authors
parent 59a3552ae6
commit f58207a28d

@ -786,13 +786,13 @@ pow_dispatch_table = _make_dispatch_table(
_Extern([jnp.float64, jnp.int32], "__nv_powi", jnp.float64),
_Fallback(
[jnp.floating, jnp.integer],
lambda ctx, x, y: math_dialect.fpowi(x, y),
math_dialect.fpowi
),
_Extern([jnp.float32, jnp.float32], "__nv_powf", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__nv_pow", jnp.float64),
_Fallback(
[jnp.floating, jnp.floating],
lambda ctx, x, y: math_dialect.powf(x, y),
math_dialect.powf
),
],
rocm=[
@ -800,13 +800,13 @@ pow_dispatch_table = _make_dispatch_table(
_Extern([jnp.float64, jnp.int32], "__ocml_pown_f64", jnp.float64),
_Fallback(
[jnp.floating, jnp.integer],
lambda ctx, x, y: math_dialect.fpowi(x, y),
math_dialect.fpowi
),
_Extern([jnp.float32, jnp.float32], "__ocml_pow_f32", jnp.float32),
_Extern([jnp.float64, jnp.float64], "__ocml_pow_f64", jnp.float64),
_Fallback(
[jnp.floating, jnp.floating],
lambda ctx, x, y: math_dialect.powf(x, y),
math_dialect.powf
),
],
)