mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add bug number for integer dots.
This commit is contained in:
parent
15361c4dde
commit
bd389b7fcf
@ -457,7 +457,8 @@ def dot(lhs, rhs):
|
||||
Returns:
|
||||
An array containing the product.
|
||||
"""
|
||||
# XLA doesn't support integer dots, so emit a sum of products instead.
|
||||
# TODO(b/134526360): XLA doesn't support integer dots, so we emit a sum of
|
||||
# products instead.
|
||||
if onp.issubdtype(lhs.dtype, onp.integer):
|
||||
lhs_shape = onp.shape(lhs)
|
||||
lhs_ndim = len(lhs_shape)
|
||||
@ -491,7 +492,8 @@ def dot_general(lhs, rhs, dimension_numbers):
|
||||
contract_dims = tuple(map(tuple, contract_dims))
|
||||
batch_dims = tuple(map(tuple, batch_dims))
|
||||
if onp.issubdtype(lhs.dtype, onp.integer):
|
||||
# XLA doesn't support integer dots, so emit a sum of products instead.
|
||||
# TODO(b/134526360): XLA doesn't support integer dots, so we emit a sum of
|
||||
# products instead.
|
||||
lhs_contract_dims, rhs_contract_dims = contract_dims
|
||||
lhs_batch_dims, rhs_batch_dims = batch_dims
|
||||
lhs_noncontract_dims = tuple(sorted(
|
||||
|
Loading…
x
Reference in New Issue
Block a user