mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
docstring for shaped iota primitive
This commit is contained in:
parent
55d6daacfa
commit
a3483dbe32
@ -983,6 +983,52 @@ mlir.register_lowering(
|
||||
|
||||
|
||||
def iota_32x2_shape(shape):
|
||||
"""Reshaped ``uint64`` iota, as two parallel ``uint32`` arrays.
|
||||
|
||||
Setting aside representation, this function essentially computes the
|
||||
equivalent of::
|
||||
|
||||
jax.lax.iota(dtype=np.uint64, size=np.prod(shape)).reshape(shape)
|
||||
|
||||
However:
|
||||
|
||||
* It returns two parallel ``uint32`` arrays instead of one
|
||||
``uint64`` array. This renders it invariant under either setting of
|
||||
the system-wide ``jax_enable_x64`` configuration flag.
|
||||
|
||||
* It lowers in a way such that the compiler's automatic SPMD
|
||||
partitioner recognizes its partitionability.
|
||||
|
||||
For example::
|
||||
|
||||
>>> import numpy as np
|
||||
>>> from jax import lax
|
||||
>>> from jax._src import prng
|
||||
|
||||
>>> prng.iota_32x2_shape((3, 4))
|
||||
[Array([[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0]], dtype=uint32),
|
||||
Array([[ 0, 1, 2, 3],
|
||||
[ 4, 5, 6, 7],
|
||||
[ 8, 9, 10, 11]], dtype=uint32)]
|
||||
|
||||
>>> def reshaped_iota(shape):
|
||||
... return lax.iota(size=np.prod(shape), dtype=np.uint32).reshape(shape)
|
||||
...
|
||||
>>> reshaped_iota((3, 4))
|
||||
Array([[ 0, 1, 2, 3],
|
||||
[ 4, 5, 6, 7],
|
||||
[ 8, 9, 10, 11]], dtype=uint32)
|
||||
|
||||
Args:
|
||||
shape: the output shape
|
||||
|
||||
Returns:
|
||||
A pair of ``uint32`` arrays ``(counts_hi, counts_lo)``, both of
|
||||
shape ``shape``, representing the higher-order and lower-order 32
|
||||
bits of the 64 bit unsigned iota.
|
||||
"""
|
||||
if len(shape) == 0:
|
||||
return (jnp.zeros((), np.dtype('uint32')),) * 2
|
||||
return iota_32x2_shape_p.bind(shape=shape)
|
||||
|
Loading…
x
Reference in New Issue
Block a user