Add bug number for integer dots.

This commit is contained in:
Peter Hawkins 2019-06-06 19:04:11 -04:00
parent 15361c4dde
commit bd389b7fcf

View File

@ -457,7 +457,8 @@ def dot(lhs, rhs):
Returns:
An array containing the product.
"""
# XLA doesn't support integer dots, so emit a sum of products instead.
# TODO(b/134526360): XLA doesn't support integer dots, so we emit a sum of
# products instead.
if onp.issubdtype(lhs.dtype, onp.integer):
lhs_shape = onp.shape(lhs)
lhs_ndim = len(lhs_shape)
@ -491,7 +492,8 @@ def dot_general(lhs, rhs, dimension_numbers):
contract_dims = tuple(map(tuple, contract_dims))
batch_dims = tuple(map(tuple, batch_dims))
if onp.issubdtype(lhs.dtype, onp.integer):
# XLA doesn't support integer dots, so emit a sum of products instead.
# TODO(b/134526360): XLA doesn't support integer dots, so we emit a sum of
# products instead.
lhs_contract_dims, rhs_contract_dims = contract_dims
lhs_batch_dims, rhs_batch_dims = batch_dims
lhs_noncontract_dims = tuple(sorted(