mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #10379 from hawkinsp:onehot
PiperOrigin-RevId: 443122227
This commit is contained in:
commit
ea0233b995
@ -383,8 +383,7 @@ def _one_hot(x: Array, num_classes: int, *,
|
||||
lhs = lax.expand_dims(x, (axis,))
|
||||
rhs_shape = [1] * x.ndim
|
||||
rhs_shape.insert(output_pos_axis, num_classes)
|
||||
rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
|
||||
rhs_shape, (output_pos_axis,))
|
||||
rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis)
|
||||
return jnp.asarray(lhs == rhs, dtype=dtype)
|
||||
|
||||
def one_hot(x: Array, num_classes: int, *,
|
||||
|
Loading…
x
Reference in New Issue
Block a user