Use jax.Array by default for doctests

PiperOrigin-RevId: 488719467
This commit is contained in:
Yash Katariya 2022-11-15 11:51:55 -08:00 committed by jax authors
parent eca12411e7
commit a419e1917a
23 changed files with 211 additions and 210 deletions

View File

@ -155,6 +155,7 @@ jobs:
- name: Test documentation
env:
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
JAX_ARRAY: 1
run: |
pytest -n 1 --tb=short docs
pytest -n 1 --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/gda_serialization --ignore=jax/collect_profile.py

View File

@ -181,7 +181,7 @@ equivalent to that of Python scalars, such as the integer scalar ``2`` in the fo
>>> x = jnp.arange(5, dtype='int8')
>>> 2 * x
DeviceArray([0, 2, 4, 6, 8], dtype=int8)
Array([0, 2, 4, 6, 8], dtype=int8)
JAX's weak type framework is designed to prevent unwanted type promotion within
binary operations between JAX values and values with no explicitly user-specified type,
@ -191,7 +191,7 @@ the expression above would lead to an implicit type promotion:
.. code-block:: python
>>> jnp.int32(2) * x
DeviceArray([0, 2, 4, 6, 8], dtype=int32)
Array([0, 2, 4, 6, 8], dtype=int32)
When used in JAX, Python scalars are sometimes promoted to :class:`~jax.numpy.DeviceArray`
objects, for example during JIT compilation. To maintain the desired promotion
@ -201,7 +201,7 @@ that can be seen in an array's string representation:
.. code-block:: python
>>> jnp.asarray(2)
DeviceArray(2, dtype=int32, weak_type=True)
Array(2, dtype=int32, weak_type=True)
If the ``dtype`` is specified explicitly, it will instead result in a standard
strongly-typed array value:
@ -209,4 +209,4 @@ strongly-typed array value:
.. code-block:: python
>>> jnp.asarray(2, dtype='int32')
DeviceArray(2, dtype=int32)
Array(2, dtype=int32)

View File

@ -189,7 +189,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
... return z
...
>>> jax.value_and_grad(g)(2.0)
(DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32, weak_type=True))
(Array(0.78907233, dtype=float32, weak_type=True), Array(-0.2556391, dtype=float32, weak_type=True))
Here, the same value is produced whether or not the :func:`jax.checkpoint`
decorator is present. When the decorator is not present, the values

View File

@ -340,7 +340,7 @@ def jit(
... return x
>>>
>>> g(jnp.arange(4), 3)
DeviceArray([ 0, 1, 256, 6561], dtype=int32)
Array([ 0, 1, 256, 6561], dtype=int32)
"""
if abstracted_axes and not config.jax_dynamic_shapes:
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
@ -1431,14 +1431,14 @@ def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': DeviceArray([[[ 2., 0.], [ 0., 0.]],
[[ 0., 0.], [ 0., 12.]]], dtype=float32),
'b': DeviceArray([[[ 1. , 0. ], [ 0. , 0. ]],
[[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32)},
'b': {'a': DeviceArray([[[ 1. , 0. ], [ 0. , 0. ]],
[[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32),
'b': DeviceArray([[[0. , 0. ], [0. , 0. ]],
[[0. , 0. ], [0. , 3.843624]]], dtype=float32)}}}
{'c': {'a': {'a': Array([[[ 2., 0.], [ 0., 0.]],
[[ 0., 0.], [ 0., 12.]]], dtype=float32),
'b': Array([[[ 1. , 0. ], [ 0. , 0. ]],
[[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32)},
'b': {'a': Array([[[ 1. , 0. ], [ 0. , 0. ]],
[[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32),
'b': Array([[[0. , 0. ], [0. , 0. ]],
[[0. , 0. ], [0. , 3.843624]]], dtype=float32)}}}
Thus each leaf in the tree structure of ``jax.hessian(fun)(x)`` corresponds to
a leaf of ``fun(x)`` and a pair of leaves of ``x``. For each leaf in
@ -1623,13 +1623,13 @@ def vmap(fun: F,
(to keep it unmapped).
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
(DeviceArray([4., 5.], dtype=float32), 8.0)
(Array([4., 5.], dtype=float32), 8.0)
If the ``out_axes`` is specified for an unmapped result, the result is
broadcast across the mapped axis:
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
(DeviceArray([4., 5.], dtype=float32), DeviceArray([8., 8.], dtype=float32, weak_type=True))
(Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))
If the ``out_axes`` is specified for a mapped result, the result is transposed
accordingly.
@ -2499,7 +2499,7 @@ def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]:
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(DeviceArray(3.26819, dtype=float32, weak_type=True), DeviceArray(-5.00753, dtype=float32, weak_type=True))
(Array(3.26819, dtype=float32, weak_type=True), Array(-5.00753, dtype=float32, weak_type=True))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
@ -2718,7 +2718,7 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32))
(Array(0.5, dtype=float32), Array(-0.5, dtype=float32))
"""
reduce_axes = _ensure_str_tuple(reduce_axes)
primals_flat, in_tree = tree_flatten(primals)
@ -3046,7 +3046,7 @@ def device_get(x: Any):
If ``x`` is a pytree, then the individual buffers are copied in parallel.
Args:
x: An array, scalar, DeviceArray or (nested) standard Python container thereof
x: An array, scalar, Array or (nested) standard Python container thereof
representing the array to be transferred to host.
Returns:
@ -3054,7 +3054,7 @@ def device_get(x: Any):
value of ``x``.
Examples:
Passing a DeviceArray:
Passing a Array:
>>> import jax
>>> x = jax.numpy.array([1., 2., 3.])

View File

@ -1066,13 +1066,13 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
... return transposed
>>> div_add(9., 3.)
DeviceArray(12., dtype=float32, weak_type=True)
Array(12., dtype=float32, weak_type=True)
>>> transpose(partial(div_add, denom=3.), 1.)(18.) # custom
DeviceArray(24., dtype=float32, weak_type=True)
Array(24., dtype=float32, weak_type=True)
>>> transpose(lambda x: x + x / 3., 1.)(18.) # reference
DeviceArray(24., dtype=float32, weak_type=True)
Array(24., dtype=float32, weak_type=True)
The above definition of ``f`` illustrates the purpose of a residual
argument: division is linear in one of its inputs (the dividend

View File

@ -69,7 +69,7 @@ class ConcretizationTypeError(JAXTypeError):
... return x.min(axis)
>>> func(jnp.arange(4), 0)
DeviceArray(0, dtype=int32)
Array(0, dtype=int32)
Traced value used in control flow
Another case where this often arises is when a traced value is used in
@ -94,7 +94,7 @@ class ConcretizationTypeError(JAXTypeError):
... return jnp.where(x.sum() < y.sum(), x, y)
>>> func(jnp.ones(4), jnp.zeros(4))
DeviceArray([0., 0., 0., 0.], dtype=float32)
Array([0., 0., 0., 0.], dtype=float32)
For more complicated control flow including loops, see
:ref:`lax-control-flow`.
@ -140,7 +140,7 @@ class ConcretizationTypeError(JAXTypeError):
... return jnp.where(x > 1, x, 0).sum()
>>> func(jnp.arange(4))
DeviceArray(5, dtype=int32)
Array(5, dtype=int32)
To understand more subtleties having to do with tracers vs. regular values,
and concrete vs. abstract values, you may want to read
@ -209,7 +209,7 @@ class NonConcreteBooleanIndexError(JAXIndexError):
... return jnp.where(x > 0, x, 0).sum()
>>> sum_of_positive(jnp.arange(-5, 5))
DeviceArray(10, dtype=int32)
Array(10, dtype=int32)
This pattern of replacing boolean masking with three-argument
:func:`~jax.numpy.where` is a common solution to this sort of problem.
@ -236,7 +236,7 @@ class NonConcreteBooleanIndexError(JAXIndexError):
... return jnp.where(x < 0, 0, x)
>>> manual_clip(jnp.arange(-2, 2))
DeviceArray([0, 0, 0, 1], dtype=int32)
Array([0, 0, 0, 1], dtype=int32)
"""
def __init__(self, tracer: core.Tracer):
super().__init__(
@ -275,7 +275,7 @@ class TracerArrayConversionError(JAXTypeError):
... return jnp.sin(x)
>>> func(jnp.arange(4))
DeviceArray([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
Array([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
Indexing a numpy array with a tracer
If this error arises on a line that involves array indexing, it may be that
@ -302,7 +302,7 @@ class TracerArrayConversionError(JAXTypeError):
... return jnp.asarray(x)[i]
>>> func(0)
DeviceArray(0, dtype=int32)
Array(0, dtype=int32)
or by declaring the index as a static argument::
@ -311,7 +311,7 @@ class TracerArrayConversionError(JAXTypeError):
... return x[i]
>>> func(0)
DeviceArray(0, dtype=int32)
Array(0, dtype=int32)
To understand more subtleties having to do with tracers vs. regular values,
and concrete vs. abstract values, you may want to read
@ -354,15 +354,15 @@ class TracerIntegerConversionError(JAXTypeError):
... return np.split(x, 2, axis)
>>> func(np.arange(10), 0)
[DeviceArray([0, 1, 2, 3, 4], dtype=int32),
DeviceArray([5, 6, 7, 8, 9], dtype=int32)]
[Array([0, 1, 2, 3, 4], dtype=int32),
Array([5, 6, 7, 8, 9], dtype=int32)]
An alternative is to apply the transformation to a closure that encapsulates
the arguments to be protected, either manually as below or by using
:func:`functools.partial`::
>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
[DeviceArray([0, 1], dtype=int32), DeviceArray([2, 3], dtype=int32)]
[Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]
**Note a new closure is created at every invocation, which defeats the
compilation caching mechanism, which is why static_argnums is preferred.**
@ -395,7 +395,7 @@ class TracerIntegerConversionError(JAXTypeError):
... return jnp.array(L)[i]
>>> func(0)
DeviceArray(1, dtype=int32)
Array(1, dtype=int32)
or by declaring the index as a static argument::
@ -404,7 +404,7 @@ class TracerIntegerConversionError(JAXTypeError):
... return L[i]
>>> func(0)
DeviceArray(1, dtype=int32, weak_type=True)
Array(1, dtype=int32, weak_type=True)
To understand more subtleties having to do with tracers vs. regular values,
and concrete vs. abstract values, you may want to read
@ -502,7 +502,7 @@ class UnexpectedTracerError(JAXTypeError):
>>> y = not_side_effecting(x)
>>> outs.append(y)
>>> outs[0] + 1 # all good! no longer a leaked value.
DeviceArray(3, dtype=int32, weak_type=True)
Array(3, dtype=int32, weak_type=True)
Leak checker
As discussed in point 2 and 3 above, JAX shows a reconstructed stack trace

View File

@ -1782,7 +1782,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
Example 1: partial sums of an array of numbers:
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
DeviceArray([0, 1, 3, 6], dtype=int32)
Array([0, 1, 3, 6], dtype=int32)
Example 2: partial products of an array of matrices
@ -1794,7 +1794,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
Example 3: reversed partial sums of an array of numbers
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
DeviceArray([6, 6, 5, 3], dtype=int32)
Array([6, 6, 5, 3], dtype=int32)
.. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.",
Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon

View File

@ -244,9 +244,9 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array:
values which appear as zero in any operations. Consider this example::
>>> jnp.nextafter(0, 1) # denormal numbers are representable
DeviceArray(1.e-45, dtype=float32, weak_type=True)
Array(1.e-45, dtype=float32, weak_type=True)
>>> jnp.nextafter(0, 1) * 1 # but are flushed to zero
DeviceArray(0., dtype=float32, weak_type=True)
Array(0., dtype=float32, weak_type=True)
For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``.
"""
@ -853,18 +853,18 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
>>> x = jnp.arange(6)
>>> y = reshape(x, (2, 3))
>>> y
DeviceArray([[0, 1, 2],
Array([[0, 1, 2],
[3, 4, 5]], dtype=int32)
Reshaping back to one dimension:
>>> reshape(y, (6,))
DeviceArray([0, 1, 2, 3, 4, 5], dtype=int32)
Array([0, 1, 2, 3, 4, 5], dtype=int32)
Reshaping to one dimension with permutation of dimensions:
>>> reshape(y, (6,), (1, 0))
DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)
Array([0, 3, 1, 4, 2, 5], dtype=int32)
"""
new_sizes = canonicalize_shape(new_sizes) # TODO
new_sizes = tuple(new_sizes)
@ -1266,13 +1266,13 @@ def stop_gradient(x: T) -> T:
For example:
>>> jax.grad(lambda x: x**2)(3.)
DeviceArray(6., dtype=float32, weak_type=True)
Array(6., dtype=float32, weak_type=True)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
DeviceArray(0., dtype=float32, weak_type=True)
Array(0., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
DeviceArray(2., dtype=float32, weak_type=True)
Array(2., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
DeviceArray(0., dtype=float32, weak_type=True)
Array(0., dtype=float32, weak_type=True)
"""
def stop(x):
# only bind primitive on inexact dtypes, to avoid some staging

View File

@ -386,9 +386,9 @@ def axis_index(axis_name):
... return lax.axis_index('i')
...
>>> f(np.zeros(4))
ShardedDeviceArray([0, 1, 2, 3], dtype=int32)
Array([0, 1, 2, 3], dtype=int32)
>>> f(np.zeros(8))
ShardedDeviceArray([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
>>> @partial(jax.pmap, axis_name='i')
... @partial(jax.pmap, axis_name='j')
... def f(_):

View File

@ -84,21 +84,21 @@ def dynamic_slice(operand: Array, start_indices: Union[Array, Sequence[ArrayLike
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
DeviceArray([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
Array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3))
DeviceArray([[ 5, 6, 7],
[ 9, 10, 11]], dtype=int32)
Array([[ 5, 6, 7],
[ 9, 10, 11]], dtype=int32)
Note the potentially surprising behavior for the case where the requested slice
overruns the bounds of the array; in this case the start index is adjusted to
return a slice of the requested size:
>>> dynamic_slice(x, (1, 1), (2, 4))
DeviceArray([[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
Array([[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
"""
start_indices = _dynamic_slice_indices(operand, start_indices)
if jax.config.jax_dynamic_shapes:
@ -129,25 +129,25 @@ def dynamic_update_slice(operand: Array, update: ArrayLike,
>>> x = jnp.zeros(6)
>>> y = jnp.ones(3)
>>> dynamic_update_slice(x, y, (2,))
DeviceArray([0., 0., 1., 1., 1., 0.], dtype=float32)
Array([0., 0., 1., 1., 1., 0.], dtype=float32)
If the update slice is too large to fit in the array, the start
index will be adjusted to make it fit
>>> dynamic_update_slice(x, y, (3,))
DeviceArray([0., 0., 0., 1., 1., 1.], dtype=float32)
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
>>> dynamic_update_slice(x, y, (5,))
DeviceArray([0., 0., 0., 1., 1., 1.], dtype=float32)
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
Here is an example of a two-dimensional slice update:
>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones((2, 2))
>>> dynamic_update_slice(x, y, (1, 2))
DeviceArray([[0., 0., 0., 0.],
[0., 0., 1., 1.],
[0., 0., 1., 1.],
[0., 0., 0., 0.]], dtype=float32)
Array([[0., 0., 0., 0.],
[0., 0., 1., 1.],
[0., 0., 1., 1.],
[0., 0., 0., 0.]], dtype=float32)
"""
start_indices = _dynamic_slice_indices(operand, start_indices)
return dynamic_update_slice_p.bind(operand, update, *start_indices)
@ -1563,7 +1563,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i]
for i in dnums.scatter_dims_to_operand_dims)
# Stack upper_bounds into a DeviceArray[n]
# Stack upper_bounds into a Array[n]
upper_bound = lax.shape_as_value(upper_bounds)
upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,

View File

@ -408,15 +408,15 @@ def one_hot(x: Array, num_classes: int, *,
``num_classes`` with the element at ``index`` set to one::
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
DeviceArray([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]], dtype=float32)
Array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]], dtype=float32)
Indicies outside the range [0, num_classes) will be encoded as zeros::
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
DeviceArray([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)
Array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)
Args:
x: A tensor of indices.

View File

@ -55,8 +55,8 @@ def zeros(key: KeyArray,
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)
Array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)
"""
return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
@ -69,9 +69,9 @@ def ones(key: KeyArray,
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
DeviceArray([[1., 1.],
[1., 1.],
[1., 1.]], dtype=float32)
Array([[1., 1.],
[1., 1.],
[1., 1.]], dtype=float32)
"""
return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))
@ -87,8 +87,8 @@ def constant(value: Array,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.constant(-7)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[-7., -7., -7.],
[-7., -7., -7.]], dtype=float32)
Array([[-7., -7., -7.],
[-7., -7., -7.]], dtype=float32)
"""
def init(key: KeyArray,
shape: core.Shape,
@ -112,8 +112,8 @@ def uniform(scale: RealNumeric = 1e-2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.uniform(10.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[7.298188 , 8.691938 , 8.7230015],
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
Array([[7.298188 , 8.691938 , 8.7230015],
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
"""
def init(key: KeyArray,
shape: core.Shape,
@ -137,8 +137,8 @@ def normal(stddev: RealNumeric = 1e-2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.normal(5.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 3.0613258 , 5.6129413 , 5.6866574 ],
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
"""
def init(key: KeyArray,
shape: core.Shape,
@ -319,8 +319,8 @@ def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 0.50350785, 0.8088631 , 0.81566876],
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
Array([[ 0.50350785, 0.8088631 , 0.81566876],
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
.. _Glorot uniform initializer: http://proceedings.mlr.press/v9/glorot10a.html
"""
@ -357,8 +357,8 @@ def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 0.41770416, 0.75262755, 0.7619329 ],
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
Array([[ 0.41770416, 0.75262755, 0.7619329 ],
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
.. _Glorot normal initializer: http://proceedings.mlr.press/v9/glorot10a.html
"""
@ -394,8 +394,8 @@ def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 0.56293887, 0.90433645, 0.9119454 ],
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
Array([[ 0.56293887, 0.90433645, 0.9119454 ],
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
.. _Lecun uniform initializer: https://arxiv.org/abs/1706.02515
"""
@ -429,8 +429,8 @@ def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 0.46700746, 0.8414632 , 0.8518669 ],
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
Array([[ 0.46700746, 0.8414632 , 0.8518669 ],
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
.. _Lecun normal initializer: https://arxiv.org/abs/1706.02515
"""
@ -465,8 +465,8 @@ def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.kaiming_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 0.79611576, 1.2789248 , 1.2896855 ],
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
Array([[ 0.79611576, 1.2789248 , 1.2896855 ],
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
.. _He uniform initializer: https://arxiv.org/abs/1502.01852
"""
@ -503,8 +503,8 @@ def he_normal(in_axis: Union[int, Sequence[int]] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.kaiming_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 0.6604483 , 1.1900088 , 1.2047218 ],
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
Array([[ 0.6604483 , 1.1900088 , 1.2047218 ],
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
.. _He normal initializer: https://arxiv.org/abs/1502.01852
"""
@ -536,8 +536,8 @@ def orthogonal(scale: RealNumeric = 1.0,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
"""
def init(key: KeyArray,
shape: core.Shape,
@ -579,17 +579,17 @@ def delta_orthogonal(
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.delta_orthogonal()
>>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]],
Array([[[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]],
<BLANKLINE>
[[ 0.27858758, -0.7949833 , -0.53887904],
[ 0.9120717 , 0.04322892, 0.40774566],
[-0.30085585, -0.6050892 , 0.73712474]],
[[ 0.27858758, -0.7949833 , -0.53887904],
[ 0.9120717 , 0.04322892, 0.40774566],
[-0.30085585, -0.6050892 , 0.73712474]],
<BLANKLINE>
[[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]]], dtype=float32)
[[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]]], dtype=float32)
.. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393

View File

@ -70,20 +70,20 @@ class _Mgrid(_IndexGrid):
Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:
>>> jnp.mgrid[0:4:1]
DeviceArray([0, 1, 2, 3], dtype=int32)
Array([0, 1, 2, 3], dtype=int32)
Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:
>>> jnp.mgrid[0:1:4j]
DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
Array([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
Multiple slices can be used to create broadcasted grids of indices:
>>> jnp.mgrid[:2, :3]
DeviceArray([[[0, 0, 0],
[1, 1, 1]],
[[0, 1, 2],
[0, 1, 2]]], dtype=int32)
Array([[[0, 0, 0],
[1, 1, 1]],
[[0, 1, 2],
[0, 1, 2]]], dtype=int32)
"""
sparse = False
op_name = "mgrid"
@ -105,19 +105,19 @@ class _Ogrid(_IndexGrid):
Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:
>>> jnp.ogrid[0:4:1]
DeviceArray([0, 1, 2, 3], dtype=int32)
Array([0, 1, 2, 3], dtype=int32)
Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:
>>> jnp.ogrid[0:1:4j]
DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
Array([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
Multiple slices can be used to create sparse grids of indices:
>>> jnp.ogrid[:2, :3]
[DeviceArray([[0],
[1]], dtype=int32),
DeviceArray([[0, 1, 2]], dtype=int32)]
[Array([[0],
[1]], dtype=int32),
Array([[0, 1, 2]], dtype=int32)]
"""
sparse = True
op_name = "ogrid"
@ -213,56 +213,56 @@ class RClass(_AxisConcat):
Passing slices in the form ``[start:stop:step]`` generates ``jnp.arange`` objects:
>>> jnp.r_[-1:5:1, 0, 0, jnp.array([1,2,3])]
DeviceArray([-1, 0, 1, 2, 3, 4, 0, 0, 1, 2, 3], dtype=int32)
Array([-1, 0, 1, 2, 3, 4, 0, 0, 1, 2, 3], dtype=int32)
An imaginary value for ``step`` will create a ``jnp.linspace`` object instead,
which includes the right endpoint:
>>> jnp.r_[-1:1:6j, 0, jnp.array([1,2,3])]
DeviceArray([-1. , -0.6 , -0.20000002, 0.20000005,
0.6 , 1. , 0. , 1. ,
2. , 3. ], dtype=float32)
Array([-1. , -0.6 , -0.20000002, 0.20000005,
0.6 , 1. , 0. , 1. ,
2. , 3. ], dtype=float32)
Use a string directive of the form ``"axis,dims,trans1d"`` as the first argument to
specify concatenation axis, minimum number of dimensions, and the position of the
upgraded array's original dimensions in the resulting array's shape tuple:
>>> jnp.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, 2D output
DeviceArray([[1, 2, 3],
[4, 5, 6]], dtype=int32)
Array([[1, 2, 3],
[4, 5, 6]], dtype=int32)
>>> jnp.r_['0,2,0', [1,2,3], [4,5,6]] # push last input axis to the front
DeviceArray([[1],
[2],
[3],
[4],
[5],
[6]], dtype=int32)
Array([[1],
[2],
[3],
[4],
[5],
[6]], dtype=int32)
Negative values for ``trans1d`` offset the last axis towards the start
of the shape tuple:
>>> jnp.r_['0,2,-2', [1,2,3], [4,5,6]]
DeviceArray([[1],
[2],
[3],
[4],
[5],
[6]], dtype=int32)
Array([[1],
[2],
[3],
[4],
[5],
[6]], dtype=int32)
Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs
to create an array with an extra row or column axis, respectively:
>>> jnp.r_['r',[1,2,3], [4,5,6]]
DeviceArray([[1, 2, 3, 4, 5, 6]], dtype=int32)
Array([[1, 2, 3, 4, 5, 6]], dtype=int32)
>>> jnp.r_['c',[1,2,3], [4,5,6]]
DeviceArray([[1],
[2],
[3],
[4],
[5],
[6]], dtype=int32)
Array([[1],
[2],
[3],
[4],
[5],
[6]], dtype=int32)
For higher-dimensional inputs (``dim >= 2``), both directives ``"r"`` and ``"c"``
give the same result.
@ -288,32 +288,32 @@ class CClass(_AxisConcat):
>>> a = jnp.arange(6).reshape((2,3))
>>> jnp.c_[a,a]
DeviceArray([[0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5]], dtype=int32)
Array([[0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5]], dtype=int32)
Use a string directive of the form ``"axis:dims:trans1d"`` as the first argument to specify
concatenation axis, minimum number of dimensions, and the position of the upgraded array's
original dimensions in the resulting array's shape tuple:
>>> jnp.c_['0,2', [1,2,3], [4,5,6]]
DeviceArray([[1],
[2],
[3],
[4],
[5],
[6]], dtype=int32)
Array([[1],
[2],
[3],
[4],
[5],
[6]], dtype=int32)
>>> jnp.c_['0,2,-1', [1,2,3], [4,5,6]]
DeviceArray([[1, 2, 3],
[4, 5, 6]], dtype=int32)
Array([[1, 2, 3],
[4, 5, 6]], dtype=int32)
Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs
to create an array with inputs stacked along the last axis:
>>> jnp.c_['r',[1,2,3], [4,5,6]]
DeviceArray([[1, 4],
[2, 5],
[3, 6]], dtype=int32)
Array([[1, 4],
[2, 5],
[3, 6]], dtype=int32)
"""
axis = -1
ndmin = 2

View File

@ -277,12 +277,12 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array:
>>> val = jnp.uint32(0xFFFFFFFF)
>>> val.astype('int32')
DeviceArray(-1, dtype=int32)
Array(-1, dtype=int32)
This function clips to the values representable in the new type:
>>> _convert_and_clip_integer(val, 'int32')
DeviceArray(2147483647, dtype=int32)
Array(2147483647, dtype=int32)
"""
val = val if isinstance(val, ndarray) else asarray(val)
dtype = dtypes.canonicalize_dtype(dtype)
@ -5158,21 +5158,21 @@ class _IndexUpdateHelper:
--------
>>> x = jnp.arange(5.0)
>>> x
DeviceArray([0., 1., 2., 3., 4.], dtype=float32)
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[2].add(10)
DeviceArray([ 0., 1., 12., 3., 4.], dtype=float32)
Array([ 0., 1., 12., 3., 4.], dtype=float32)
>>> x.at[10].add(10) # out-of-bounds indices are ignored
DeviceArray([0., 1., 2., 3., 4.], dtype=float32)
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[20].add(10, mode='clip')
DeviceArray([ 0., 1., 2., 3., 14.], dtype=float32)
Array([ 0., 1., 2., 3., 14.], dtype=float32)
>>> x.at[2].get()
DeviceArray(2., dtype=float32)
Array(2., dtype=float32)
>>> x.at[20].get() # out-of-bounds indices clipped
DeviceArray(4., dtype=float32)
Array(4., dtype=float32)
>>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN
DeviceArray(nan, dtype=float32)
Array(nan, dtype=float32)
>>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value
DeviceArray(-1., dtype=float32)
Array(-1., dtype=float32)
"""
__slots__ = ("array",)

View File

@ -65,11 +65,11 @@ roots will be padded with NaN values:
# The default behavior matches numpy and strips leading zeros:
>>> jnp.roots(coeffs)
DeviceArray([-2.+0.j], dtype=complex64)
Array([-2.+0.j], dtype=complex64)
# With strip_zeros=False, extra roots are set to NaN:
>>> jnp.roots(coeffs, strip_zeros=False)
DeviceArray([-2. +0.j, nan+nanj], dtype=complex64)
Array([-2. +0.j, nan+nanj], dtype=complex64)
""",
extra_params="""
strip_zeros : bool, default=True

View File

@ -229,13 +229,13 @@ def segment_sum(data: Array,
>>> data = jnp.arange(5)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2])
>>> segment_sum(data, segment_ids)
DeviceArray([1, 5, 4], dtype=int32)
Array([1, 5, 4], dtype=int32)
Using JIT requires static `num_segments`:
>>> from jax import jit
>>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
DeviceArray([1, 5, 4], dtype=int32)
Array([1, 5, 4], dtype=int32)
"""
return _segment_update(
"segment_sum", data, segment_ids, lax.scatter_add, num_segments,
@ -285,13 +285,13 @@ def segment_prod(data: Array,
>>> data = jnp.arange(6)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_prod(data, segment_ids)
DeviceArray([ 0, 6, 20], dtype=int32)
Array([ 0, 6, 20], dtype=int32)
Using JIT requires static `num_segments`:
>>> from jax import jit
>>> jit(segment_prod, static_argnums=2)(data, segment_ids, 3)
DeviceArray([ 0, 6, 20], dtype=int32)
Array([ 0, 6, 20], dtype=int32)
"""
return _segment_update(
"segment_prod", data, segment_ids, lax.scatter_mul, num_segments,
@ -340,13 +340,13 @@ def segment_max(data: Array,
>>> data = jnp.arange(6)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_max(data, segment_ids)
DeviceArray([1, 3, 5], dtype=int32)
Array([1, 3, 5], dtype=int32)
Using JIT requires static `num_segments`:
>>> from jax import jit
>>> jit(segment_max, static_argnums=2)(data, segment_ids, 3)
DeviceArray([1, 3, 5], dtype=int32)
Array([1, 3, 5], dtype=int32)
"""
return _segment_update(
"segment_max", data, segment_ids, lax.scatter_max, num_segments,
@ -395,13 +395,13 @@ def segment_min(data: Array,
>>> data = jnp.arange(6)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_min(data, segment_ids)
DeviceArray([0, 2, 4], dtype=int32)
Array([0, 2, 4], dtype=int32)
Using JIT requires static `num_segments`:
>>> from jax import jit
>>> jit(segment_min, static_argnums=2)(data, segment_ids, 3)
DeviceArray([0, 2, 4], dtype=int32)
Array([0, 2, 4], dtype=int32)
"""
return _segment_update(
"segment_min", data, segment_ids, lax.scatter_min, num_segments,

View File

@ -201,7 +201,7 @@ def _enable_cpp_named_sharding():
@pxla.use_cpp_class(_enable_cpp_named_sharding())
class NamedSharding(XLACompatibleSharding):
"""NamedSharding is a way to express ``Sharding``s using named axes.
r"""NamedSharding is a way to express ``Sharding``\s using named axes.
``Mesh`` and ``PartitionSpec`` can be used to express a ``Sharding`` with a name.

View File

@ -324,7 +324,7 @@ class Partial(functools.partial):
>>> import jax.numpy as jnp
>>> add_one = Partial(jnp.add, 1)
>>> add_one(2)
DeviceArray(3, dtype=int32, weak_type=True)
Array(3, dtype=int32, weak_type=True)
Pytree compatibility means that the resulting partial function can be passed
as an argument within transformed JAX functions, which is not possible with a
@ -336,13 +336,13 @@ class Partial(functools.partial):
... return f(*args)
...
>>> call_func(add_one, 2)
DeviceArray(3, dtype=int32, weak_type=True)
Array(3, dtype=int32, weak_type=True)
Passing zero arguments to ``Partial`` effectively wraps the original function,
making it a valid argument in JAX transformed functions:
>>> call_func(Partial(jnp.add), 1, 2)
DeviceArray(3, dtype=int32, weak_type=True)
Array(3, dtype=int32, weak_type=True)
Had we passed ``jnp.add`` to ``call_func`` directly, it would have resulted in a
``TypeError``.

View File

@ -415,8 +415,8 @@ def xmap(fun: Callable,
>>> xmap(jnp.vdot,
... in_axes=({0: 'left'}, {1: 'right'}),
... out_axes=['left', 'right', ...])(x, x.T)
DeviceArray([[ 30, 80],
[ 80, 255]], dtype=int32)
Array([[ 30, 80],
[ 80, 255]], dtype=int32)
Note that the contraction in the program is performed over the positional axes,
while named axes are just a convenient way to achieve batching. While this

View File

@ -45,21 +45,21 @@ Here is an example of creating a sparse array from a dense array:
Convert back to a dense array with the ``todense()`` method:
>>> M_sp.todense()
DeviceArray([[0., 1., 0., 2.],
[3., 0., 0., 0.],
[0., 0., 4., 0.]], dtype=float32)
Array([[0., 1., 0., 2.],
[3., 0., 0., 0.],
[0., 0., 4., 0.]], dtype=float32)
The BCOO format is a somewhat modified version of the standard COO format, and the dense
representation can be seen in the ``data`` and ``indices`` attributes:
>>> M_sp.data # Explicitly stored data
DeviceArray([1., 2., 3., 4.], dtype=float32)
Array([1., 2., 3., 4.], dtype=float32)
>>> M_sp.indices # Indices of the stored data
DeviceArray([[0, 1],
[0, 3],
[1, 0],
[2, 2]], dtype=int32)
Array([[0, 1],
[0, 3],
[1, 0],
[2, 2]], dtype=int32)
BCOO objects have familiar array-like attributes, as well as sparse-specific attributes:
@ -82,10 +82,10 @@ product:
>>> y = jnp.array([3., 6., 5.])
>>> M_sp.T @ y
DeviceArray([18., 3., 20., 6.], dtype=float32)
Array([18., 3., 20., 6.], dtype=float32)
>>> M.T @ y # Compare to dense version
DeviceArray([18., 3., 20., 6.], dtype=float32)
Array([18., 3., 20., 6.], dtype=float32)
BCOO objects are designed to be compatible with JAX transforms, including :func:`jax.jit`,
:func:`jax.vmap`, :func:`jax.grad`, and others. For example:
@ -96,7 +96,7 @@ BCOO objects are designed to be compatible with JAX transforms, including :func:
... return (M_sp.T @ y).sum()
...
>>> jit(grad(f))(y)
DeviceArray([3., 3., 4.], dtype=float32)
Array([3., 3., 4.], dtype=float32)
Note, however, that under normal circumstances :mod:`jax.numpy` and :mod:`jax.lax` functions
do not know how to handle sparse matrices, so attempting to compute things like
@ -114,7 +114,7 @@ Consider this function, which computes a more complicated result from a matrix a
... return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
...
>>> f(M, y)
DeviceArray([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
Were we to pass a sparse matrix to this directly, it would result in an error, because ``jnp``
functions do not recognize sparse inputs. However, with :func:`sparsify`, we get a version of
@ -123,7 +123,7 @@ this function that does accept sparse matrices:
>>> f_sp = sparse.sparsify(f)
>>> f_sp(M_sp, y)
DeviceArray([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
Currently support for :func:`sparsify` is limited to a couple dozen primitives, including:

View File

@ -2331,17 +2331,17 @@ class BCOO(JAXSparse):
Examine the internal representation:
>>> M_sp.data
DeviceArray([2., 1., 4.], dtype=float32)
Array([2., 1., 4.], dtype=float32)
>>> M_sp.indices
DeviceArray([[0, 1],
[1, 0],
[1, 2]], dtype=int32)
Array([[0, 1],
[1, 0],
[1, 2]], dtype=int32)
Create a dense array from a sparse array:
>>> M_sp.todense()
DeviceArray([[0., 2., 0.],
[1., 0., 4.]], dtype=float32)
Array([[0., 2., 0.],
[1., 0., 4.]], dtype=float32)
Create a sparse array from COO data & indices:
@ -2353,9 +2353,9 @@ class BCOO(JAXSparse):
>>> mat
BCOO(float32[3, 3], nse=3)
>>> mat.todense()
DeviceArray([[1., 0., 0.],
[0., 3., 0.],
[0., 0., 5.]], dtype=float32)
Array([[1., 0., 0.],
[0., 3., 0.],
[0., 0., 5.]], dtype=float32)
"""
# Note: additional BCOO methods are defined in transform.py

View File

@ -34,16 +34,16 @@ For example:
... return -(jnp.sin(mat) @ vec)
...
>>> f(mat, vec)
DeviceArray([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
-0.15574613], dtype=float32)
Array([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
-0.15574613], dtype=float32)
>>> mat_sparse = BCOO.fromdense(mat)
>>> mat_sparse
BCOO(float32[5, 5], nse=8)
>>> sparsify(f)(mat_sparse, vec)
DeviceArray([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
-0.15574613], dtype=float32)
Array([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
-0.15574613], dtype=float32)
"""
import functools
@ -446,7 +446,7 @@ def sparsify(f, use_tracer=False):
>>> v = jnp.array([3, 4, 2])
>>> f(M, v)
DeviceArray([ 64, 82, 100, 118], dtype=int32)
Array([ 64, 82, 100, 118], dtype=int32)
"""
if use_tracer:
return _sparsify_with_tracer(f)

View File

@ -39,23 +39,23 @@ usually generated by the :py:func:`jax.random.PRNGKey` function::
>>> from jax import random
>>> key = random.PRNGKey(0)
>>> key
DeviceArray([0, 0], dtype=uint32)
Array([0, 0], dtype=uint32)
This key can then be used in any of JAX's random number generation routines::
>>> random.uniform(key)
DeviceArray(0.41845703, dtype=float32)
Array(0.41845703, dtype=float32)
Note that using a key does not modify it, so reusing the same key will lead to the same result::
>>> random.uniform(key)
DeviceArray(0.41845703, dtype=float32)
Array(0.41845703, dtype=float32)
If you need a new random number, you can use :meth:`jax.random.split` to generate new subkeys::
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
DeviceArray(0.10536897, dtype=float32)
Array(0.10536897, dtype=float32)
Advanced
--------