mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
BroadcastedIota needs integer type (fixes #728)
This commit is contained in:
parent
b1fd8e6eb6
commit
8f9e4b1260
@ -3700,7 +3700,7 @@ class _EyeConstant(xla.DeviceConstant):
|
||||
else:
|
||||
etype = xla_bridge.dtype_to_etype_exact(diag_const.dtype)
|
||||
etype = xla_bridge.dtype_to_etype(diag_const.dtype)
|
||||
iotas = [c.BroadcastedIota(onp.bool_, diag_const.shape, axis)
|
||||
iotas = [c.BroadcastedIota(onp.uint32, diag_const.shape, axis)
|
||||
for axis in diag_const.axes]
|
||||
eyes = [c.Eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
|
||||
return c.ConvertElementType(_reduce(c.And, eyes), etype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user