BroadcastedIota needs integer type (fixes #728)

This commit is contained in:
Matthew Johnson 2019-05-17 12:46:11 -07:00
parent b1fd8e6eb6
commit 8f9e4b1260

View File

@ -3700,7 +3700,7 @@ class _EyeConstant(xla.DeviceConstant):
else:
etype = xla_bridge.dtype_to_etype_exact(diag_const.dtype)
etype = xla_bridge.dtype_to_etype(diag_const.dtype)
iotas = [c.BroadcastedIota(onp.bool_, diag_const.shape, axis)
iotas = [c.BroadcastedIota(onp.uint32, diag_const.shape, axis)
for axis in diag_const.axes]
eyes = [c.Eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
return c.ConvertElementType(_reduce(c.And, eyes), etype)