Added a Pallas GPU test for jnp.invert

PiperOrigin-RevId: 615369570
This commit is contained in:
Sergei Lebedev 2024-03-13 04:47:54 -07:00 committed by jax authors
parent 926e673f61
commit f0c5051004
2 changed files with 2 additions and 5 deletions

View File

@ -517,10 +517,7 @@ triton_lowering_rules[lax.cumsum_p] = _cumsum_lowering_rule
def _not_lowering_rule(ctx: LoweringRuleContext, x):
[x_aval] = ctx.avals_in
if not np.issubdtype(x_aval.dtype, jnp.integer):
raise NotImplementedError(f"unsupported type: {x_aval.dtype}")
one = _full(x.type, 0xFFFFFFFFFFFFFFFF)
return arith_dialect.xori(x, one)
return arith_dialect.xori(x, _full(x.type, ~x_aval.dtype.type(0)))
triton_lowering_rules[lax.not_p] = _not_lowering_rule

View File

@ -1532,7 +1532,7 @@ class PallasOpsTest(PallasTest):
# fmt: on
["float32", "float64"]
),
([lax.population_count, lax.clz], ["int32", "int64"]),
([lax.population_count, lax.clz, jnp.invert], ["int32", "int64"]),
]
@parameterized.named_parameters(