mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Added a Pallas GPU test for jnp.invert
PiperOrigin-RevId: 615369570
This commit is contained in:
parent
926e673f61
commit
f0c5051004
@ -517,10 +517,7 @@ triton_lowering_rules[lax.cumsum_p] = _cumsum_lowering_rule
|
||||
|
||||
def _not_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
[x_aval] = ctx.avals_in
|
||||
if not np.issubdtype(x_aval.dtype, jnp.integer):
|
||||
raise NotImplementedError(f"unsupported type: {x_aval.dtype}")
|
||||
one = _full(x.type, 0xFFFFFFFFFFFFFFFF)
|
||||
return arith_dialect.xori(x, one)
|
||||
return arith_dialect.xori(x, _full(x.type, ~x_aval.dtype.type(0)))
|
||||
|
||||
|
||||
triton_lowering_rules[lax.not_p] = _not_lowering_rule
|
||||
|
@ -1532,7 +1532,7 @@ class PallasOpsTest(PallasTest):
|
||||
# fmt: on
|
||||
["float32", "float64"]
|
||||
),
|
||||
([lax.population_count, lax.clz], ["int32", "int64"]),
|
||||
([lax.population_count, lax.clz, jnp.invert], ["int32", "int64"]),
|
||||
]
|
||||
|
||||
@parameterized.named_parameters(
|
||||
|
Loading…
x
Reference in New Issue
Block a user