mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix type specifications for bitwise ops. (#2054)
This commit is contained in:
parent
bb176d414b
commit
a3de80201f
@ -1649,6 +1649,7 @@ _bool = {onp.bool_}
|
||||
|
||||
_num = _int | _float | _complex
|
||||
_any = _int | _float | _complex | _bool
|
||||
_bool_or_int = _int | _bool
|
||||
|
||||
neg_p = standard_unop(_num, 'neg')
|
||||
ad.deflinear(neg_p, lambda t: [neg(t)])
|
||||
@ -1814,15 +1815,15 @@ def _pow_jvp_rhs(g, ans, x, y):
|
||||
ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
|
||||
_replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)
|
||||
|
||||
not_p = standard_unop(_int | _bool, 'not')
|
||||
not_p = standard_unop(_bool_or_int, 'not')
|
||||
|
||||
and_p = standard_naryop([_any, _any], 'and')
|
||||
and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and')
|
||||
ad.defjvp_zero(and_p)
|
||||
|
||||
or_p = standard_naryop([_any, _any], 'or')
|
||||
or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or')
|
||||
ad.defjvp_zero(or_p)
|
||||
|
||||
xor_p = standard_naryop([_any, _any], 'xor')
|
||||
xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
|
||||
ad.defjvp_zero(xor_p)
|
||||
|
||||
def _add_transpose(t, x, y):
|
||||
|
Loading…
x
Reference in New Issue
Block a user