Merge pull request #10379 from hawkinsp:onehot

PiperOrigin-RevId: 443122227
This commit is contained in:
jax authors 2022-04-20 09:57:05 -07:00
commit ea0233b995

View File

@ -383,8 +383,7 @@ def _one_hot(x: Array, num_classes: int, *,
lhs = lax.expand_dims(x, (axis,))
rhs_shape = [1] * x.ndim
rhs_shape.insert(output_pos_axis, num_classes)
rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
rhs_shape, (output_pos_axis,))
rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis)
return jnp.asarray(lhs == rhs, dtype=dtype)
def one_hot(x: Array, num_classes: int, *,