Fix type specifications for bitwise ops. (#2054)

This commit is contained in:
Peter Hawkins 2020-01-23 11:53:55 -05:00 committed by GitHub
parent bb176d414b
commit a3de80201f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1649,6 +1649,7 @@ _bool = {onp.bool_}
_num = _int | _float | _complex
_any = _int | _float | _complex | _bool
_bool_or_int = _int | _bool
neg_p = standard_unop(_num, 'neg')
ad.deflinear(neg_p, lambda t: [neg(t)])
@ -1814,15 +1815,15 @@ def _pow_jvp_rhs(g, ans, x, y):
ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
_replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)
not_p = standard_unop(_int | _bool, 'not')
not_p = standard_unop(_bool_or_int, 'not')
and_p = standard_naryop([_any, _any], 'and')
and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and')
ad.defjvp_zero(and_p)
or_p = standard_naryop([_any, _any], 'or')
or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or')
ad.defjvp_zero(or_p)
xor_p = standard_naryop([_any, _any], 'xor')
xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
ad.defjvp_zero(xor_p)
def _add_transpose(t, x, y):