mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
[pallas:triton] Fixed dispatch tablee for lax.pow_p
PiperOrigin-RevId: 722817510
This commit is contained in:
parent
59a3552ae6
commit
f58207a28d
@ -786,13 +786,13 @@ pow_dispatch_table = _make_dispatch_table(
|
||||
_Extern([jnp.float64, jnp.int32], "__nv_powi", jnp.float64),
|
||||
_Fallback(
|
||||
[jnp.floating, jnp.integer],
|
||||
lambda ctx, x, y: math_dialect.fpowi(x, y),
|
||||
math_dialect.fpowi
|
||||
),
|
||||
_Extern([jnp.float32, jnp.float32], "__nv_powf", jnp.float32),
|
||||
_Extern([jnp.float64, jnp.float64], "__nv_pow", jnp.float64),
|
||||
_Fallback(
|
||||
[jnp.floating, jnp.floating],
|
||||
lambda ctx, x, y: math_dialect.powf(x, y),
|
||||
math_dialect.powf
|
||||
),
|
||||
],
|
||||
rocm=[
|
||||
@ -800,13 +800,13 @@ pow_dispatch_table = _make_dispatch_table(
|
||||
_Extern([jnp.float64, jnp.int32], "__ocml_pown_f64", jnp.float64),
|
||||
_Fallback(
|
||||
[jnp.floating, jnp.integer],
|
||||
lambda ctx, x, y: math_dialect.fpowi(x, y),
|
||||
math_dialect.fpowi
|
||||
),
|
||||
_Extern([jnp.float32, jnp.float32], "__ocml_pow_f32", jnp.float32),
|
||||
_Extern([jnp.float64, jnp.float64], "__ocml_pow_f64", jnp.float64),
|
||||
_Fallback(
|
||||
[jnp.floating, jnp.floating],
|
||||
lambda ctx, x, y: math_dialect.powf(x, y),
|
||||
math_dialect.powf
|
||||
),
|
||||
],
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user