mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
This commit is contained in:
parent
eca12411e7
commit
a419e1917a
1
.github/workflows/ci-build.yaml
vendored
1
.github/workflows/ci-build.yaml
vendored
@ -155,6 +155,7 @@ jobs:
|
||||
- name: Test documentation
|
||||
env:
|
||||
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
|
||||
JAX_ARRAY: 1
|
||||
run: |
|
||||
pytest -n 1 --tb=short docs
|
||||
pytest -n 1 --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/gda_serialization --ignore=jax/collect_profile.py
|
||||
|
@ -181,7 +181,7 @@ equivalent to that of Python scalars, such as the integer scalar ``2`` in the fo
|
||||
|
||||
>>> x = jnp.arange(5, dtype='int8')
|
||||
>>> 2 * x
|
||||
DeviceArray([0, 2, 4, 6, 8], dtype=int8)
|
||||
Array([0, 2, 4, 6, 8], dtype=int8)
|
||||
|
||||
JAX's weak type framework is designed to prevent unwanted type promotion within
|
||||
binary operations between JAX values and values with no explicitly user-specified type,
|
||||
@ -191,7 +191,7 @@ the expression above would lead to an implicit type promotion:
|
||||
.. code-block:: python
|
||||
|
||||
>>> jnp.int32(2) * x
|
||||
DeviceArray([0, 2, 4, 6, 8], dtype=int32)
|
||||
Array([0, 2, 4, 6, 8], dtype=int32)
|
||||
|
||||
When used in JAX, Python scalars are sometimes promoted to :class:`~jax.Array`
|
||||
objects, for example during JIT compilation. To maintain the desired promotion
|
||||
@ -201,7 +201,7 @@ that can be seen in an array's string representation:
|
||||
.. code-block:: python
|
||||
|
||||
>>> jnp.asarray(2)
|
||||
DeviceArray(2, dtype=int32, weak_type=True)
|
||||
Array(2, dtype=int32, weak_type=True)
|
||||
|
||||
If the ``dtype`` is specified explicitly, it will instead result in a standard
|
||||
strongly-typed array value:
|
||||
@ -209,4 +209,4 @@ strongly-typed array value:
|
||||
.. code-block:: python
|
||||
|
||||
>>> jnp.asarray(2, dtype='int32')
|
||||
DeviceArray(2, dtype=int32)
|
||||
Array(2, dtype=int32)
|
@ -189,7 +189,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
|
||||
... return z
|
||||
...
|
||||
>>> jax.value_and_grad(g)(2.0)
|
||||
(DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32, weak_type=True))
|
||||
(Array(0.78907233, dtype=float32, weak_type=True), Array(-0.2556391, dtype=float32, weak_type=True))
|
||||
|
||||
Here, the same value is produced whether or not the :func:`jax.checkpoint`
|
||||
decorator is present. When the decorator is not present, the values
|
||||
|
@ -340,7 +340,7 @@ def jit(
|
||||
... return x
|
||||
>>>
|
||||
>>> g(jnp.arange(4), 3)
|
||||
DeviceArray([ 0, 1, 256, 6561], dtype=int32)
|
||||
Array([ 0, 1, 256, 6561], dtype=int32)
|
||||
"""
|
||||
if abstracted_axes and not config.jax_dynamic_shapes:
|
||||
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
|
||||
@ -1431,14 +1431,14 @@ def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
>>> import jax.numpy as jnp
|
||||
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
|
||||
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
|
||||
{'c': {'a': {'a': DeviceArray([[[ 2., 0.], [ 0., 0.]],
|
||||
[[ 0., 0.], [ 0., 12.]]], dtype=float32),
|
||||
'b': DeviceArray([[[ 1. , 0. ], [ 0. , 0. ]],
|
||||
[[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32)},
|
||||
'b': {'a': DeviceArray([[[ 1. , 0. ], [ 0. , 0. ]],
|
||||
[[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32),
|
||||
'b': DeviceArray([[[0. , 0. ], [0. , 0. ]],
|
||||
[[0. , 0. ], [0. , 3.843624]]], dtype=float32)}}}
|
||||
{'c': {'a': {'a': Array([[[ 2., 0.], [ 0., 0.]],
|
||||
[[ 0., 0.], [ 0., 12.]]], dtype=float32),
|
||||
'b': Array([[[ 1. , 0. ], [ 0. , 0. ]],
|
||||
[[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32)},
|
||||
'b': {'a': Array([[[ 1. , 0. ], [ 0. , 0. ]],
|
||||
[[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32),
|
||||
'b': Array([[[0. , 0. ], [0. , 0. ]],
|
||||
[[0. , 0. ], [0. , 3.843624]]], dtype=float32)}}}
|
||||
|
||||
Thus each leaf in the tree structure of ``jax.hessian(fun)(x)`` corresponds to
|
||||
a leaf of ``fun(x)`` and a pair of leaves of ``x``. For each leaf in
|
||||
@ -1623,13 +1623,13 @@ def vmap(fun: F,
|
||||
(to keep it unmapped).
|
||||
|
||||
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
|
||||
(DeviceArray([4., 5.], dtype=float32), 8.0)
|
||||
(Array([4., 5.], dtype=float32), 8.0)
|
||||
|
||||
If the ``out_axes`` is specified for an unmapped result, the result is
|
||||
broadcast across the mapped axis:
|
||||
|
||||
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
|
||||
(DeviceArray([4., 5.], dtype=float32), DeviceArray([8., 8.], dtype=float32, weak_type=True))
|
||||
(Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))
|
||||
|
||||
If the ``out_axes`` is specified for a mapped result, the result is transposed
|
||||
accordingly.
|
||||
@ -2499,7 +2499,7 @@ def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]:
|
||||
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
|
||||
...
|
||||
>>> jax.jvp(f, (2.,), (3.,))
|
||||
(DeviceArray(3.26819, dtype=float32, weak_type=True), DeviceArray(-5.00753, dtype=float32, weak_type=True))
|
||||
(Array(3.26819, dtype=float32, weak_type=True), Array(-5.00753, dtype=float32, weak_type=True))
|
||||
>>> y, f_jvp = jax.linearize(f, 2.)
|
||||
>>> print(y)
|
||||
3.2681944
|
||||
@ -2718,7 +2718,7 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
|
||||
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
|
||||
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
|
||||
>>> f_transpose(1.0)
|
||||
(DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32))
|
||||
(Array(0.5, dtype=float32), Array(-0.5, dtype=float32))
|
||||
"""
|
||||
reduce_axes = _ensure_str_tuple(reduce_axes)
|
||||
primals_flat, in_tree = tree_flatten(primals)
|
||||
@ -3046,7 +3046,7 @@ def device_get(x: Any):
|
||||
If ``x`` is a pytree, then the individual buffers are copied in parallel.
|
||||
|
||||
Args:
|
||||
x: An array, scalar, DeviceArray or (nested) standard Python container thereof
|
||||
x: An array, scalar, Array or (nested) standard Python container thereof
|
||||
representing the array to be transferred to host.
|
||||
|
||||
Returns:
|
||||
@ -3054,7 +3054,7 @@ def device_get(x: Any):
|
||||
value of ``x``.
|
||||
|
||||
Examples:
|
||||
Passing a DeviceArray:
|
||||
Passing an Array:
|
||||
|
||||
>>> import jax
|
||||
>>> x = jax.numpy.array([1., 2., 3.])
|
||||
|
@ -1066,13 +1066,13 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
|
||||
... return transposed
|
||||
|
||||
>>> div_add(9., 3.)
|
||||
DeviceArray(12., dtype=float32, weak_type=True)
|
||||
Array(12., dtype=float32, weak_type=True)
|
||||
|
||||
>>> transpose(partial(div_add, denom=3.), 1.)(18.) # custom
|
||||
DeviceArray(24., dtype=float32, weak_type=True)
|
||||
Array(24., dtype=float32, weak_type=True)
|
||||
|
||||
>>> transpose(lambda x: x + x / 3., 1.)(18.) # reference
|
||||
DeviceArray(24., dtype=float32, weak_type=True)
|
||||
Array(24., dtype=float32, weak_type=True)
|
||||
|
||||
The above definition of ``f`` illustrates the purpose of a residual
|
||||
argument: division is linear in one of its inputs (the dividend
|
||||
|
@ -69,7 +69,7 @@ class ConcretizationTypeError(JAXTypeError):
|
||||
... return x.min(axis)
|
||||
|
||||
>>> func(jnp.arange(4), 0)
|
||||
DeviceArray(0, dtype=int32)
|
||||
Array(0, dtype=int32)
|
||||
|
||||
Traced value used in control flow
|
||||
Another case where this often arises is when a traced value is used in
|
||||
@ -94,7 +94,7 @@ class ConcretizationTypeError(JAXTypeError):
|
||||
... return jnp.where(x.sum() < y.sum(), x, y)
|
||||
|
||||
>>> func(jnp.ones(4), jnp.zeros(4))
|
||||
DeviceArray([0., 0., 0., 0.], dtype=float32)
|
||||
Array([0., 0., 0., 0.], dtype=float32)
|
||||
|
||||
For more complicated control flow including loops, see
|
||||
:ref:`lax-control-flow`.
|
||||
@ -140,7 +140,7 @@ class ConcretizationTypeError(JAXTypeError):
|
||||
... return jnp.where(x > 1, x, 0).sum()
|
||||
|
||||
>>> func(jnp.arange(4))
|
||||
DeviceArray(5, dtype=int32)
|
||||
Array(5, dtype=int32)
|
||||
|
||||
To understand more subtleties having to do with tracers vs. regular values,
|
||||
and concrete vs. abstract values, you may want to read
|
||||
@ -209,7 +209,7 @@ class NonConcreteBooleanIndexError(JAXIndexError):
|
||||
... return jnp.where(x > 0, x, 0).sum()
|
||||
|
||||
>>> sum_of_positive(jnp.arange(-5, 5))
|
||||
DeviceArray(10, dtype=int32)
|
||||
Array(10, dtype=int32)
|
||||
|
||||
This pattern of replacing boolean masking with three-argument
|
||||
:func:`~jax.numpy.where` is a common solution to this sort of problem.
|
||||
@ -236,7 +236,7 @@ class NonConcreteBooleanIndexError(JAXIndexError):
|
||||
... return jnp.where(x < 0, 0, x)
|
||||
|
||||
>>> manual_clip(jnp.arange(-2, 2))
|
||||
DeviceArray([0, 0, 0, 1], dtype=int32)
|
||||
Array([0, 0, 0, 1], dtype=int32)
|
||||
"""
|
||||
def __init__(self, tracer: core.Tracer):
|
||||
super().__init__(
|
||||
@ -275,7 +275,7 @@ class TracerArrayConversionError(JAXTypeError):
|
||||
... return jnp.sin(x)
|
||||
|
||||
>>> func(jnp.arange(4))
|
||||
DeviceArray([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
|
||||
Array([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
|
||||
|
||||
Indexing a numpy array with a tracer
|
||||
If this error arises on a line that involves array indexing, it may be that
|
||||
@ -302,7 +302,7 @@ class TracerArrayConversionError(JAXTypeError):
|
||||
... return jnp.asarray(x)[i]
|
||||
|
||||
>>> func(0)
|
||||
DeviceArray(0, dtype=int32)
|
||||
Array(0, dtype=int32)
|
||||
|
||||
or by declaring the index as a static argument::
|
||||
|
||||
@ -311,7 +311,7 @@ class TracerArrayConversionError(JAXTypeError):
|
||||
... return x[i]
|
||||
|
||||
>>> func(0)
|
||||
DeviceArray(0, dtype=int32)
|
||||
Array(0, dtype=int32)
|
||||
|
||||
To understand more subtleties having to do with tracers vs. regular values,
|
||||
and concrete vs. abstract values, you may want to read
|
||||
@ -354,15 +354,15 @@ class TracerIntegerConversionError(JAXTypeError):
|
||||
... return np.split(x, 2, axis)
|
||||
|
||||
>>> func(np.arange(10), 0)
|
||||
[DeviceArray([0, 1, 2, 3, 4], dtype=int32),
|
||||
DeviceArray([5, 6, 7, 8, 9], dtype=int32)]
|
||||
[Array([0, 1, 2, 3, 4], dtype=int32),
|
||||
Array([5, 6, 7, 8, 9], dtype=int32)]
|
||||
|
||||
An alternative is to apply the transformation to a closure that encapsulates
|
||||
the arguments to be protected, either manually as below or by using
|
||||
:func:`functools.partial`::
|
||||
|
||||
>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
|
||||
[DeviceArray([0, 1], dtype=int32), DeviceArray([2, 3], dtype=int32)]
|
||||
[Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]
|
||||
|
||||
**Note a new closure is created at every invocation, which defeats the
|
||||
compilation caching mechanism, which is why static_argnums is preferred.**
|
||||
@ -395,7 +395,7 @@ class TracerIntegerConversionError(JAXTypeError):
|
||||
... return jnp.array(L)[i]
|
||||
|
||||
>>> func(0)
|
||||
DeviceArray(1, dtype=int32)
|
||||
Array(1, dtype=int32)
|
||||
|
||||
or by declaring the index as a static argument::
|
||||
|
||||
@ -404,7 +404,7 @@ class TracerIntegerConversionError(JAXTypeError):
|
||||
... return L[i]
|
||||
|
||||
>>> func(0)
|
||||
DeviceArray(1, dtype=int32, weak_type=True)
|
||||
Array(1, dtype=int32, weak_type=True)
|
||||
|
||||
To understand more subtleties having to do with tracers vs. regular values,
|
||||
and concrete vs. abstract values, you may want to read
|
||||
@ -502,7 +502,7 @@ class UnexpectedTracerError(JAXTypeError):
|
||||
>>> y = not_side_effecting(x)
|
||||
>>> outs.append(y)
|
||||
>>> outs[0] + 1 # all good! no longer a leaked value.
|
||||
DeviceArray(3, dtype=int32, weak_type=True)
|
||||
Array(3, dtype=int32, weak_type=True)
|
||||
|
||||
Leak checker
|
||||
As discussed in points 2 and 3 above, JAX shows a reconstructed stack trace
|
||||
|
@ -1782,7 +1782,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
|
||||
Example 1: partial sums of an array of numbers:
|
||||
|
||||
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
|
||||
DeviceArray([0, 1, 3, 6], dtype=int32)
|
||||
Array([0, 1, 3, 6], dtype=int32)
|
||||
|
||||
Example 2: partial products of an array of matrices
|
||||
|
||||
@ -1794,7 +1794,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
|
||||
Example 3: reversed partial sums of an array of numbers
|
||||
|
||||
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
|
||||
DeviceArray([6, 6, 5, 3], dtype=int32)
|
||||
Array([6, 6, 5, 3], dtype=int32)
|
||||
|
||||
.. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.",
|
||||
Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon
|
||||
|
@ -244,9 +244,9 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
values which appear as zero in any operations. Consider this example::
|
||||
|
||||
>>> jnp.nextafter(0, 1) # denormal numbers are representable
|
||||
DeviceArray(1.e-45, dtype=float32, weak_type=True)
|
||||
Array(1.e-45, dtype=float32, weak_type=True)
|
||||
>>> jnp.nextafter(0, 1) * 1 # but are flushed to zero
|
||||
DeviceArray(0., dtype=float32, weak_type=True)
|
||||
Array(0., dtype=float32, weak_type=True)
|
||||
|
||||
For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``.
|
||||
"""
|
||||
@ -853,18 +853,18 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
|
||||
>>> x = jnp.arange(6)
|
||||
>>> y = reshape(x, (2, 3))
|
||||
>>> y
|
||||
DeviceArray([[0, 1, 2],
|
||||
Array([[0, 1, 2],
|
||||
[3, 4, 5]], dtype=int32)
|
||||
|
||||
Reshaping back to one dimension:
|
||||
|
||||
>>> reshape(y, (6,))
|
||||
DeviceArray([0, 1, 2, 3, 4, 5], dtype=int32)
|
||||
Array([0, 1, 2, 3, 4, 5], dtype=int32)
|
||||
|
||||
Reshaping to one dimension with permutation of dimensions:
|
||||
|
||||
>>> reshape(y, (6,), (1, 0))
|
||||
DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)
|
||||
Array([0, 3, 1, 4, 2, 5], dtype=int32)
|
||||
"""
|
||||
new_sizes = canonicalize_shape(new_sizes) # TODO
|
||||
new_sizes = tuple(new_sizes)
|
||||
@ -1266,13 +1266,13 @@ def stop_gradient(x: T) -> T:
|
||||
For example:
|
||||
|
||||
>>> jax.grad(lambda x: x**2)(3.)
|
||||
DeviceArray(6., dtype=float32, weak_type=True)
|
||||
Array(6., dtype=float32, weak_type=True)
|
||||
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
|
||||
DeviceArray(0., dtype=float32, weak_type=True)
|
||||
Array(0., dtype=float32, weak_type=True)
|
||||
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
|
||||
DeviceArray(2., dtype=float32, weak_type=True)
|
||||
Array(2., dtype=float32, weak_type=True)
|
||||
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
|
||||
DeviceArray(0., dtype=float32, weak_type=True)
|
||||
Array(0., dtype=float32, weak_type=True)
|
||||
"""
|
||||
def stop(x):
|
||||
# only bind primitive on inexact dtypes, to avoid some staging
|
||||
|
@ -386,9 +386,9 @@ def axis_index(axis_name):
|
||||
... return lax.axis_index('i')
|
||||
...
|
||||
>>> f(np.zeros(4))
|
||||
ShardedDeviceArray([0, 1, 2, 3], dtype=int32)
|
||||
Array([0, 1, 2, 3], dtype=int32)
|
||||
>>> f(np.zeros(8))
|
||||
ShardedDeviceArray([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
|
||||
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
|
||||
>>> @partial(jax.pmap, axis_name='i')
|
||||
... @partial(jax.pmap, axis_name='j')
|
||||
... def f(_):
|
||||
|
@ -84,21 +84,21 @@ def dynamic_slice(operand: Array, start_indices: Union[Array, Sequence[ArrayLike
|
||||
|
||||
>>> x = jnp.arange(12).reshape(3, 4)
|
||||
>>> x
|
||||
DeviceArray([[ 0, 1, 2, 3],
|
||||
[ 4, 5, 6, 7],
|
||||
[ 8, 9, 10, 11]], dtype=int32)
|
||||
Array([[ 0, 1, 2, 3],
|
||||
[ 4, 5, 6, 7],
|
||||
[ 8, 9, 10, 11]], dtype=int32)
|
||||
|
||||
>>> dynamic_slice(x, (1, 1), (2, 3))
|
||||
DeviceArray([[ 5, 6, 7],
|
||||
[ 9, 10, 11]], dtype=int32)
|
||||
Array([[ 5, 6, 7],
|
||||
[ 9, 10, 11]], dtype=int32)
|
||||
|
||||
Note the potentially surprising behavior for the case where the requested slice
|
||||
overruns the bounds of the array; in this case the start index is adjusted to
|
||||
return a slice of the requested size:
|
||||
|
||||
>>> dynamic_slice(x, (1, 1), (2, 4))
|
||||
DeviceArray([[ 4, 5, 6, 7],
|
||||
[ 8, 9, 10, 11]], dtype=int32)
|
||||
Array([[ 4, 5, 6, 7],
|
||||
[ 8, 9, 10, 11]], dtype=int32)
|
||||
"""
|
||||
start_indices = _dynamic_slice_indices(operand, start_indices)
|
||||
if jax.config.jax_dynamic_shapes:
|
||||
@ -129,25 +129,25 @@ def dynamic_update_slice(operand: Array, update: ArrayLike,
|
||||
>>> x = jnp.zeros(6)
|
||||
>>> y = jnp.ones(3)
|
||||
>>> dynamic_update_slice(x, y, (2,))
|
||||
DeviceArray([0., 0., 1., 1., 1., 0.], dtype=float32)
|
||||
Array([0., 0., 1., 1., 1., 0.], dtype=float32)
|
||||
|
||||
If the update slice is too large to fit in the array, the start
|
||||
index will be adjusted to make it fit
|
||||
|
||||
>>> dynamic_update_slice(x, y, (3,))
|
||||
DeviceArray([0., 0., 0., 1., 1., 1.], dtype=float32)
|
||||
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
|
||||
>>> dynamic_update_slice(x, y, (5,))
|
||||
DeviceArray([0., 0., 0., 1., 1., 1.], dtype=float32)
|
||||
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
|
||||
|
||||
Here is an example of a two-dimensional slice update:
|
||||
|
||||
>>> x = jnp.zeros((4, 4))
|
||||
>>> y = jnp.ones((2, 2))
|
||||
>>> dynamic_update_slice(x, y, (1, 2))
|
||||
DeviceArray([[0., 0., 0., 0.],
|
||||
[0., 0., 1., 1.],
|
||||
[0., 0., 1., 1.],
|
||||
[0., 0., 0., 0.]], dtype=float32)
|
||||
Array([[0., 0., 0., 0.],
|
||||
[0., 0., 1., 1.],
|
||||
[0., 0., 1., 1.],
|
||||
[0., 0., 0., 0.]], dtype=float32)
|
||||
"""
|
||||
start_indices = _dynamic_slice_indices(operand, start_indices)
|
||||
return dynamic_update_slice_p.bind(operand, update, *start_indices)
|
||||
@ -1563,7 +1563,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
|
||||
|
||||
upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i]
|
||||
for i in dnums.scatter_dims_to_operand_dims)
|
||||
# Stack upper_bounds into a DeviceArray[n]
|
||||
# Stack upper_bounds into an Array[n]
|
||||
upper_bound = lax.shape_as_value(upper_bounds)
|
||||
upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
|
||||
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
|
||||
|
@ -408,15 +408,15 @@ def one_hot(x: Array, num_classes: int, *,
|
||||
``num_classes`` with the element at ``index`` set to one::
|
||||
|
||||
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
|
||||
DeviceArray([[1., 0., 0.],
|
||||
[0., 1., 0.],
|
||||
[0., 0., 1.]], dtype=float32)
|
||||
Array([[1., 0., 0.],
|
||||
[0., 1., 0.],
|
||||
[0., 0., 1.]], dtype=float32)
|
||||
|
||||
Indices outside the range [0, num_classes) will be encoded as zeros::
|
||||
|
||||
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
|
||||
DeviceArray([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)
|
||||
Array([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)
|
||||
|
||||
Args:
|
||||
x: A tensor of indices.
|
||||
|
@ -55,8 +55,8 @@ def zeros(key: KeyArray,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
DeviceArray([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)
|
||||
Array([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)
|
||||
"""
|
||||
return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
@ -69,9 +69,9 @@ def ones(key: KeyArray,
|
||||
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
|
||||
DeviceArray([[1., 1.],
|
||||
[1., 1.],
|
||||
[1., 1.]], dtype=float32)
|
||||
Array([[1., 1.],
|
||||
[1., 1.],
|
||||
[1., 1.]], dtype=float32)
|
||||
"""
|
||||
return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
@ -87,8 +87,8 @@ def constant(value: Array,
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.constant(-7)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
|
||||
DeviceArray([[-7., -7., -7.],
|
||||
[-7., -7., -7.]], dtype=float32)
|
||||
Array([[-7., -7., -7.],
|
||||
[-7., -7., -7.]], dtype=float32)
|
||||
"""
|
||||
def init(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
@ -112,8 +112,8 @@ def uniform(scale: RealNumeric = 1e-2,
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.uniform(10.0)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[7.298188 , 8.691938 , 8.7230015],
|
||||
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
|
||||
Array([[7.298188 , 8.691938 , 8.7230015],
|
||||
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
|
||||
"""
|
||||
def init(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
@ -137,8 +137,8 @@ def normal(stddev: RealNumeric = 1e-2,
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.normal(5.0)
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 3.0613258 , 5.6129413 , 5.6866574 ],
|
||||
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
|
||||
Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
|
||||
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
|
||||
"""
|
||||
def init(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
@ -319,8 +319,8 @@ def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.glorot_uniform()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 0.50350785, 0.8088631 , 0.81566876],
|
||||
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
|
||||
Array([[ 0.50350785, 0.8088631 , 0.81566876],
|
||||
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
|
||||
|
||||
.. _Glorot uniform initializer: http://proceedings.mlr.press/v9/glorot10a.html
|
||||
"""
|
||||
@ -357,8 +357,8 @@ def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.glorot_normal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 0.41770416, 0.75262755, 0.7619329 ],
|
||||
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
|
||||
Array([[ 0.41770416, 0.75262755, 0.7619329 ],
|
||||
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
|
||||
|
||||
.. _Glorot normal initializer: http://proceedings.mlr.press/v9/glorot10a.html
|
||||
"""
|
||||
@ -394,8 +394,8 @@ def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.lecun_uniform()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 0.56293887, 0.90433645, 0.9119454 ],
|
||||
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
|
||||
Array([[ 0.56293887, 0.90433645, 0.9119454 ],
|
||||
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
|
||||
|
||||
.. _Lecun uniform initializer: https://arxiv.org/abs/1706.02515
|
||||
"""
|
||||
@ -429,8 +429,8 @@ def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.lecun_normal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 0.46700746, 0.8414632 , 0.8518669 ],
|
||||
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
|
||||
Array([[ 0.46700746, 0.8414632 , 0.8518669 ],
|
||||
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
|
||||
|
||||
.. _Lecun normal initializer: https://arxiv.org/abs/1706.02515
|
||||
"""
|
||||
@ -465,8 +465,8 @@ def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.kaiming_uniform()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 0.79611576, 1.2789248 , 1.2896855 ],
|
||||
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
|
||||
Array([[ 0.79611576, 1.2789248 , 1.2896855 ],
|
||||
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
|
||||
|
||||
.. _He uniform initializer: https://arxiv.org/abs/1502.01852
|
||||
"""
|
||||
@ -503,8 +503,8 @@ def he_normal(in_axis: Union[int, Sequence[int]] = -2,
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.kaiming_normal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 0.6604483 , 1.1900088 , 1.2047218 ],
|
||||
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
|
||||
Array([[ 0.6604483 , 1.1900088 , 1.2047218 ],
|
||||
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
|
||||
|
||||
.. _He normal initializer: https://arxiv.org/abs/1502.01852
|
||||
"""
|
||||
@ -536,8 +536,8 @@ def orthogonal(scale: RealNumeric = 1.0,
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.orthogonal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
|
||||
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
|
||||
Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
|
||||
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
|
||||
"""
|
||||
def init(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
@ -579,17 +579,17 @@ def delta_orthogonal(
|
||||
>>> import jax, jax.numpy as jnp
|
||||
>>> initializer = jax.nn.initializers.delta_orthogonal()
|
||||
>>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32) # doctest: +SKIP
|
||||
DeviceArray([[[ 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. ]],
|
||||
Array([[[ 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. ]],
|
||||
<BLANKLINE>
|
||||
[[ 0.27858758, -0.7949833 , -0.53887904],
|
||||
[ 0.9120717 , 0.04322892, 0.40774566],
|
||||
[-0.30085585, -0.6050892 , 0.73712474]],
|
||||
[[ 0.27858758, -0.7949833 , -0.53887904],
|
||||
[ 0.9120717 , 0.04322892, 0.40774566],
|
||||
[-0.30085585, -0.6050892 , 0.73712474]],
|
||||
<BLANKLINE>
|
||||
[[ 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. ]]], dtype=float32)
|
||||
[[ 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. ],
|
||||
[ 0. , 0. , 0. ]]], dtype=float32)
|
||||
|
||||
|
||||
.. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393
|
||||
|
@ -70,20 +70,20 @@ class _Mgrid(_IndexGrid):
|
||||
Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:
|
||||
|
||||
>>> jnp.mgrid[0:4:1]
|
||||
DeviceArray([0, 1, 2, 3], dtype=int32)
|
||||
Array([0, 1, 2, 3], dtype=int32)
|
||||
|
||||
Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:
|
||||
|
||||
>>> jnp.mgrid[0:1:4j]
|
||||
DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
|
||||
Array([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
|
||||
|
||||
Multiple slices can be used to create broadcasted grids of indices:
|
||||
|
||||
>>> jnp.mgrid[:2, :3]
|
||||
DeviceArray([[[0, 0, 0],
|
||||
[1, 1, 1]],
|
||||
[[0, 1, 2],
|
||||
[0, 1, 2]]], dtype=int32)
|
||||
Array([[[0, 0, 0],
|
||||
[1, 1, 1]],
|
||||
[[0, 1, 2],
|
||||
[0, 1, 2]]], dtype=int32)
|
||||
"""
|
||||
sparse = False
|
||||
op_name = "mgrid"
|
||||
@ -105,19 +105,19 @@ class _Ogrid(_IndexGrid):
|
||||
Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:
|
||||
|
||||
>>> jnp.ogrid[0:4:1]
|
||||
DeviceArray([0, 1, 2, 3], dtype=int32)
|
||||
Array([0, 1, 2, 3], dtype=int32)
|
||||
|
||||
Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:
|
||||
|
||||
>>> jnp.ogrid[0:1:4j]
|
||||
DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
|
||||
Array([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
|
||||
|
||||
Multiple slices can be used to create sparse grids of indices:
|
||||
|
||||
>>> jnp.ogrid[:2, :3]
|
||||
[DeviceArray([[0],
|
||||
[1]], dtype=int32),
|
||||
DeviceArray([[0, 1, 2]], dtype=int32)]
|
||||
[Array([[0],
|
||||
[1]], dtype=int32),
|
||||
Array([[0, 1, 2]], dtype=int32)]
|
||||
"""
|
||||
sparse = True
|
||||
op_name = "ogrid"
|
||||
@ -213,56 +213,56 @@ class RClass(_AxisConcat):
|
||||
Passing slices in the form ``[start:stop:step]`` generates ``jnp.arange`` objects:
|
||||
|
||||
>>> jnp.r_[-1:5:1, 0, 0, jnp.array([1,2,3])]
|
||||
DeviceArray([-1, 0, 1, 2, 3, 4, 0, 0, 1, 2, 3], dtype=int32)
|
||||
Array([-1, 0, 1, 2, 3, 4, 0, 0, 1, 2, 3], dtype=int32)
|
||||
|
||||
An imaginary value for ``step`` will create a ``jnp.linspace`` object instead,
|
||||
which includes the right endpoint:
|
||||
|
||||
>>> jnp.r_[-1:1:6j, 0, jnp.array([1,2,3])]
|
||||
DeviceArray([-1. , -0.6 , -0.20000002, 0.20000005,
|
||||
0.6 , 1. , 0. , 1. ,
|
||||
2. , 3. ], dtype=float32)
|
||||
Array([-1. , -0.6 , -0.20000002, 0.20000005,
|
||||
0.6 , 1. , 0. , 1. ,
|
||||
2. , 3. ], dtype=float32)
|
||||
|
||||
Use a string directive of the form ``"axis,dims,trans1d"`` as the first argument to
|
||||
specify concatenation axis, minimum number of dimensions, and the position of the
|
||||
upgraded array's original dimensions in the resulting array's shape tuple:
|
||||
|
||||
>>> jnp.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, 2D output
|
||||
DeviceArray([[1, 2, 3],
|
||||
[4, 5, 6]], dtype=int32)
|
||||
Array([[1, 2, 3],
|
||||
[4, 5, 6]], dtype=int32)
|
||||
|
||||
>>> jnp.r_['0,2,0', [1,2,3], [4,5,6]] # push last input axis to the front
|
||||
DeviceArray([[1],
|
||||
[2],
|
||||
[3],
|
||||
[4],
|
||||
[5],
|
||||
[6]], dtype=int32)
|
||||
Array([[1],
|
||||
[2],
|
||||
[3],
|
||||
[4],
|
||||
[5],
|
||||
[6]], dtype=int32)
|
||||
|
||||
Negative values for ``trans1d`` offset the last axis towards the start
|
||||
of the shape tuple:
|
||||
|
||||
>>> jnp.r_['0,2,-2', [1,2,3], [4,5,6]]
|
||||
DeviceArray([[1],
|
||||
[2],
|
||||
[3],
|
||||
[4],
|
||||
[5],
|
||||
[6]], dtype=int32)
|
||||
Array([[1],
|
||||
[2],
|
||||
[3],
|
||||
[4],
|
||||
[5],
|
||||
[6]], dtype=int32)
|
||||
|
||||
Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs
|
||||
to create an array with an extra row or column axis, respectively:
|
||||
|
||||
>>> jnp.r_['r',[1,2,3], [4,5,6]]
|
||||
DeviceArray([[1, 2, 3, 4, 5, 6]], dtype=int32)
|
||||
Array([[1, 2, 3, 4, 5, 6]], dtype=int32)
|
||||
|
||||
>>> jnp.r_['c',[1,2,3], [4,5,6]]
|
||||
DeviceArray([[1],
|
||||
[2],
|
||||
[3],
|
||||
[4],
|
||||
[5],
|
||||
[6]], dtype=int32)
|
||||
Array([[1],
|
||||
[2],
|
||||
[3],
|
||||
[4],
|
||||
[5],
|
||||
[6]], dtype=int32)
|
||||
|
||||
For higher-dimensional inputs (``dim >= 2``), both directives ``"r"`` and ``"c"``
|
||||
give the same result.
|
||||
@ -288,32 +288,32 @@ class CClass(_AxisConcat):
|
||||
|
||||
>>> a = jnp.arange(6).reshape((2,3))
|
||||
>>> jnp.c_[a,a]
|
||||
DeviceArray([[0, 1, 2, 0, 1, 2],
|
||||
[3, 4, 5, 3, 4, 5]], dtype=int32)
|
||||
Array([[0, 1, 2, 0, 1, 2],
|
||||
[3, 4, 5, 3, 4, 5]], dtype=int32)
|
||||
|
||||
Use a string directive of the form ``"axis,dims,trans1d"`` as the first argument to specify
|
||||
concatenation axis, minimum number of dimensions, and the position of the upgraded array's
|
||||
original dimensions in the resulting array's shape tuple:
|
||||
|
||||
>>> jnp.c_['0,2', [1,2,3], [4,5,6]]
|
||||
DeviceArray([[1],
|
||||
[2],
|
||||
[3],
|
||||
[4],
|
||||
[5],
|
||||
[6]], dtype=int32)
|
||||
Array([[1],
|
||||
[2],
|
||||
[3],
|
||||
[4],
|
||||
[5],
|
||||
[6]], dtype=int32)
|
||||
|
||||
>>> jnp.c_['0,2,-1', [1,2,3], [4,5,6]]
|
||||
DeviceArray([[1, 2, 3],
|
||||
[4, 5, 6]], dtype=int32)
|
||||
Array([[1, 2, 3],
|
||||
[4, 5, 6]], dtype=int32)
|
||||
|
||||
Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs
|
||||
to create an array with inputs stacked along the last axis:
|
||||
|
||||
>>> jnp.c_['r',[1,2,3], [4,5,6]]
|
||||
DeviceArray([[1, 4],
|
||||
[2, 5],
|
||||
[3, 6]], dtype=int32)
|
||||
Array([[1, 4],
|
||||
[2, 5],
|
||||
[3, 6]], dtype=int32)
|
||||
"""
|
||||
axis = -1
|
||||
ndmin = 2
|
||||
|
@ -277,12 +277,12 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array:
|
||||
|
||||
>>> val = jnp.uint32(0xFFFFFFFF)
|
||||
>>> val.astype('int32')
|
||||
DeviceArray(-1, dtype=int32)
|
||||
Array(-1, dtype=int32)
|
||||
|
||||
This function clips to the values representable in the new type:
|
||||
|
||||
>>> _convert_and_clip_integer(val, 'int32')
|
||||
DeviceArray(2147483647, dtype=int32)
|
||||
Array(2147483647, dtype=int32)
|
||||
"""
|
||||
val = val if isinstance(val, ndarray) else asarray(val)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
@ -5158,21 +5158,21 @@ class _IndexUpdateHelper:
|
||||
--------
|
||||
>>> x = jnp.arange(5.0)
|
||||
>>> x
|
||||
DeviceArray([0., 1., 2., 3., 4.], dtype=float32)
|
||||
Array([0., 1., 2., 3., 4.], dtype=float32)
|
||||
>>> x.at[2].add(10)
|
||||
DeviceArray([ 0., 1., 12., 3., 4.], dtype=float32)
|
||||
Array([ 0., 1., 12., 3., 4.], dtype=float32)
|
||||
>>> x.at[10].add(10) # out-of-bounds indices are ignored
|
||||
DeviceArray([0., 1., 2., 3., 4.], dtype=float32)
|
||||
Array([0., 1., 2., 3., 4.], dtype=float32)
|
||||
>>> x.at[20].add(10, mode='clip')
|
||||
DeviceArray([ 0., 1., 2., 3., 14.], dtype=float32)
|
||||
Array([ 0., 1., 2., 3., 14.], dtype=float32)
|
||||
>>> x.at[2].get()
|
||||
DeviceArray(2., dtype=float32)
|
||||
Array(2., dtype=float32)
|
||||
>>> x.at[20].get() # out-of-bounds indices clipped
|
||||
DeviceArray(4., dtype=float32)
|
||||
Array(4., dtype=float32)
|
||||
>>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN
|
||||
DeviceArray(nan, dtype=float32)
|
||||
Array(nan, dtype=float32)
|
||||
>>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value
|
||||
DeviceArray(-1., dtype=float32)
|
||||
Array(-1., dtype=float32)
|
||||
"""
|
||||
__slots__ = ("array",)
|
||||
|
||||
|
@ -65,11 +65,11 @@ roots will be padded with NaN values:
|
||||
|
||||
# The default behavior matches numpy and strips leading zeros:
|
||||
>>> jnp.roots(coeffs)
|
||||
DeviceArray([-2.+0.j], dtype=complex64)
|
||||
Array([-2.+0.j], dtype=complex64)
|
||||
|
||||
# With strip_zeros=False, extra roots are set to NaN:
|
||||
>>> jnp.roots(coeffs, strip_zeros=False)
|
||||
DeviceArray([-2. +0.j, nan+nanj], dtype=complex64)
|
||||
Array([-2. +0.j, nan+nanj], dtype=complex64)
|
||||
""",
|
||||
extra_params="""
|
||||
strip_zeros : bool, default=True
|
||||
|
@ -229,13 +229,13 @@ def segment_sum(data: Array,
|
||||
>>> data = jnp.arange(5)
|
||||
>>> segment_ids = jnp.array([0, 0, 1, 1, 2])
|
||||
>>> segment_sum(data, segment_ids)
|
||||
DeviceArray([1, 5, 4], dtype=int32)
|
||||
Array([1, 5, 4], dtype=int32)
|
||||
|
||||
Using JIT requires static `num_segments`:
|
||||
|
||||
>>> from jax import jit
|
||||
>>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
|
||||
DeviceArray([1, 5, 4], dtype=int32)
|
||||
Array([1, 5, 4], dtype=int32)
|
||||
"""
|
||||
return _segment_update(
|
||||
"segment_sum", data, segment_ids, lax.scatter_add, num_segments,
|
||||
@ -285,13 +285,13 @@ def segment_prod(data: Array,
|
||||
>>> data = jnp.arange(6)
|
||||
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
|
||||
>>> segment_prod(data, segment_ids)
|
||||
DeviceArray([ 0, 6, 20], dtype=int32)
|
||||
Array([ 0, 6, 20], dtype=int32)
|
||||
|
||||
Using JIT requires static `num_segments`:
|
||||
|
||||
>>> from jax import jit
|
||||
>>> jit(segment_prod, static_argnums=2)(data, segment_ids, 3)
|
||||
DeviceArray([ 0, 6, 20], dtype=int32)
|
||||
Array([ 0, 6, 20], dtype=int32)
|
||||
"""
|
||||
return _segment_update(
|
||||
"segment_prod", data, segment_ids, lax.scatter_mul, num_segments,
|
||||
@ -340,13 +340,13 @@ def segment_max(data: Array,
|
||||
>>> data = jnp.arange(6)
|
||||
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
|
||||
>>> segment_max(data, segment_ids)
|
||||
DeviceArray([1, 3, 5], dtype=int32)
|
||||
Array([1, 3, 5], dtype=int32)
|
||||
|
||||
Using JIT requires static `num_segments`:
|
||||
|
||||
>>> from jax import jit
|
||||
>>> jit(segment_max, static_argnums=2)(data, segment_ids, 3)
|
||||
DeviceArray([1, 3, 5], dtype=int32)
|
||||
Array([1, 3, 5], dtype=int32)
|
||||
"""
|
||||
return _segment_update(
|
||||
"segment_max", data, segment_ids, lax.scatter_max, num_segments,
|
||||
@ -395,13 +395,13 @@ def segment_min(data: Array,
|
||||
>>> data = jnp.arange(6)
|
||||
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
|
||||
>>> segment_min(data, segment_ids)
|
||||
DeviceArray([0, 2, 4], dtype=int32)
|
||||
Array([0, 2, 4], dtype=int32)
|
||||
|
||||
Using JIT requires static `num_segments`:
|
||||
|
||||
>>> from jax import jit
|
||||
>>> jit(segment_min, static_argnums=2)(data, segment_ids, 3)
|
||||
DeviceArray([0, 2, 4], dtype=int32)
|
||||
Array([0, 2, 4], dtype=int32)
|
||||
"""
|
||||
return _segment_update(
|
||||
"segment_min", data, segment_ids, lax.scatter_min, num_segments,
|
||||
|
@ -201,7 +201,7 @@ def _enable_cpp_named_sharding():
|
||||
|
||||
@pxla.use_cpp_class(_enable_cpp_named_sharding())
|
||||
class NamedSharding(XLACompatibleSharding):
|
||||
"""NamedSharding is a way to express ``Sharding``s using named axes.
|
||||
r"""NamedSharding is a way to express ``Sharding``\s using named axes.
|
||||
|
||||
``Mesh`` and ``PartitionSpec`` can be used to express a ``Sharding`` with a name.
|
||||
|
||||
|
@ -324,7 +324,7 @@ class Partial(functools.partial):
|
||||
>>> import jax.numpy as jnp
|
||||
>>> add_one = Partial(jnp.add, 1)
|
||||
>>> add_one(2)
|
||||
DeviceArray(3, dtype=int32, weak_type=True)
|
||||
Array(3, dtype=int32, weak_type=True)
|
||||
|
||||
Pytree compatibility means that the resulting partial function can be passed
|
||||
as an argument within transformed JAX functions, which is not possible with a
|
||||
@ -336,13 +336,13 @@ class Partial(functools.partial):
|
||||
... return f(*args)
|
||||
...
|
||||
>>> call_func(add_one, 2)
|
||||
DeviceArray(3, dtype=int32, weak_type=True)
|
||||
Array(3, dtype=int32, weak_type=True)
|
||||
|
||||
Passing zero arguments to ``Partial`` effectively wraps the original function,
|
||||
making it a valid argument in JAX transformed functions:
|
||||
|
||||
>>> call_func(Partial(jnp.add), 1, 2)
|
||||
DeviceArray(3, dtype=int32, weak_type=True)
|
||||
Array(3, dtype=int32, weak_type=True)
|
||||
|
||||
Had we passed ``jnp.add`` to ``call_func`` directly, it would have resulted in a
|
||||
``TypeError``.
|
||||
|
@ -415,8 +415,8 @@ def xmap(fun: Callable,
|
||||
>>> xmap(jnp.vdot,
|
||||
... in_axes=({0: 'left'}, {1: 'right'}),
|
||||
... out_axes=['left', 'right', ...])(x, x.T)
|
||||
DeviceArray([[ 30, 80],
|
||||
[ 80, 255]], dtype=int32)
|
||||
Array([[ 30, 80],
|
||||
[ 80, 255]], dtype=int32)
|
||||
|
||||
Note that the contraction in the program is performed over the positional axes,
|
||||
while named axes are just a convenient way to achieve batching. While this
|
||||
|
@ -45,21 +45,21 @@ Here is an example of creating a sparse array from a dense array:
|
||||
Convert back to a dense array with the ``todense()`` method:
|
||||
|
||||
>>> M_sp.todense()
|
||||
DeviceArray([[0., 1., 0., 2.],
|
||||
[3., 0., 0., 0.],
|
||||
[0., 0., 4., 0.]], dtype=float32)
|
||||
Array([[0., 1., 0., 2.],
|
||||
[3., 0., 0., 0.],
|
||||
[0., 0., 4., 0.]], dtype=float32)
|
||||
|
||||
The BCOO format is a somewhat modified version of the standard COO format, and the dense
|
||||
representation can be seen in the ``data`` and ``indices`` attributes:
|
||||
|
||||
>>> M_sp.data # Explicitly stored data
|
||||
DeviceArray([1., 2., 3., 4.], dtype=float32)
|
||||
Array([1., 2., 3., 4.], dtype=float32)
|
||||
|
||||
>>> M_sp.indices # Indices of the stored data
|
||||
DeviceArray([[0, 1],
|
||||
[0, 3],
|
||||
[1, 0],
|
||||
[2, 2]], dtype=int32)
|
||||
Array([[0, 1],
|
||||
[0, 3],
|
||||
[1, 0],
|
||||
[2, 2]], dtype=int32)
|
||||
|
||||
BCOO objects have familiar array-like attributes, as well as sparse-specific attributes:
|
||||
|
||||
@ -82,10 +82,10 @@ product:
|
||||
>>> y = jnp.array([3., 6., 5.])
|
||||
|
||||
>>> M_sp.T @ y
|
||||
DeviceArray([18., 3., 20., 6.], dtype=float32)
|
||||
Array([18., 3., 20., 6.], dtype=float32)
|
||||
|
||||
>>> M.T @ y # Compare to dense version
|
||||
DeviceArray([18., 3., 20., 6.], dtype=float32)
|
||||
Array([18., 3., 20., 6.], dtype=float32)
|
||||
|
||||
BCOO objects are designed to be compatible with JAX transforms, including :func:`jax.jit`,
|
||||
:func:`jax.vmap`, :func:`jax.grad`, and others. For example:
|
||||
@ -96,7 +96,7 @@ BCOO objects are designed to be compatible with JAX transforms, including :func:
|
||||
... return (M_sp.T @ y).sum()
|
||||
...
|
||||
>>> jit(grad(f))(y)
|
||||
DeviceArray([3., 3., 4.], dtype=float32)
|
||||
Array([3., 3., 4.], dtype=float32)
|
||||
|
||||
Note, however, that under normal circumstances :mod:`jax.numpy` and :mod:`jax.lax` functions
|
||||
do not know how to handle sparse matrices, so attempting to compute things like
|
||||
@ -114,7 +114,7 @@ Consider this function, which computes a more complicated result from a matrix a
|
||||
... return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
|
||||
...
|
||||
>>> f(M, y)
|
||||
DeviceArray([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
|
||||
Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
|
||||
|
||||
Were we to pass a sparse matrix to this directly, it would result in an error, because ``jnp``
|
||||
functions do not recognize sparse inputs. However, with :func:`sparsify`, we get a version of
|
||||
@ -123,7 +123,7 @@ this function that does accept sparse matrices:
|
||||
>>> f_sp = sparse.sparsify(f)
|
||||
|
||||
>>> f_sp(M_sp, y)
|
||||
DeviceArray([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
|
||||
Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
|
||||
|
||||
Currently support for :func:`sparsify` is limited to a couple dozen primitives, including:
|
||||
|
||||
|
@ -2331,17 +2331,17 @@ class BCOO(JAXSparse):
|
||||
Examine the internal representation:
|
||||
|
||||
>>> M_sp.data
|
||||
DeviceArray([2., 1., 4.], dtype=float32)
|
||||
Array([2., 1., 4.], dtype=float32)
|
||||
>>> M_sp.indices
|
||||
DeviceArray([[0, 1],
|
||||
[1, 0],
|
||||
[1, 2]], dtype=int32)
|
||||
Array([[0, 1],
|
||||
[1, 0],
|
||||
[1, 2]], dtype=int32)
|
||||
|
||||
Create a dense array from a sparse array:
|
||||
|
||||
>>> M_sp.todense()
|
||||
DeviceArray([[0., 2., 0.],
|
||||
[1., 0., 4.]], dtype=float32)
|
||||
Array([[0., 2., 0.],
|
||||
[1., 0., 4.]], dtype=float32)
|
||||
|
||||
Create a sparse array from COO data & indices:
|
||||
|
||||
@ -2353,9 +2353,9 @@ class BCOO(JAXSparse):
|
||||
>>> mat
|
||||
BCOO(float32[3, 3], nse=3)
|
||||
>>> mat.todense()
|
||||
DeviceArray([[1., 0., 0.],
|
||||
[0., 3., 0.],
|
||||
[0., 0., 5.]], dtype=float32)
|
||||
Array([[1., 0., 0.],
|
||||
[0., 3., 0.],
|
||||
[0., 0., 5.]], dtype=float32)
|
||||
"""
|
||||
# Note: additional BCOO methods are defined in transform.py
|
||||
|
||||
|
@ -34,16 +34,16 @@ For example:
|
||||
... return -(jnp.sin(mat) @ vec)
|
||||
...
|
||||
>>> f(mat, vec)
|
||||
DeviceArray([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
|
||||
-0.15574613], dtype=float32)
|
||||
Array([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
|
||||
-0.15574613], dtype=float32)
|
||||
|
||||
>>> mat_sparse = BCOO.fromdense(mat)
|
||||
>>> mat_sparse
|
||||
BCOO(float32[5, 5], nse=8)
|
||||
|
||||
>>> sparsify(f)(mat_sparse, vec)
|
||||
DeviceArray([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
|
||||
-0.15574613], dtype=float32)
|
||||
Array([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
|
||||
-0.15574613], dtype=float32)
|
||||
"""
|
||||
|
||||
import functools
|
||||
@ -446,7 +446,7 @@ def sparsify(f, use_tracer=False):
|
||||
>>> v = jnp.array([3, 4, 2])
|
||||
|
||||
>>> f(M, v)
|
||||
DeviceArray([ 64, 82, 100, 118], dtype=int32)
|
||||
Array([ 64, 82, 100, 118], dtype=int32)
|
||||
"""
|
||||
if use_tracer:
|
||||
return _sparsify_with_tracer(f)
|
||||
|
@ -39,23 +39,23 @@ usually generated by the :py:func:`jax.random.PRNGKey` function::
|
||||
>>> from jax import random
|
||||
>>> key = random.PRNGKey(0)
|
||||
>>> key
|
||||
DeviceArray([0, 0], dtype=uint32)
|
||||
Array([0, 0], dtype=uint32)
|
||||
|
||||
This key can then be used in any of JAX's random number generation routines::
|
||||
|
||||
>>> random.uniform(key)
|
||||
DeviceArray(0.41845703, dtype=float32)
|
||||
Array(0.41845703, dtype=float32)
|
||||
|
||||
Note that using a key does not modify it, so reusing the same key will lead to the same result::
|
||||
|
||||
>>> random.uniform(key)
|
||||
DeviceArray(0.41845703, dtype=float32)
|
||||
Array(0.41845703, dtype=float32)
|
||||
|
||||
If you need a new random number, you can use :meth:`jax.random.split` to generate new subkeys::
|
||||
|
||||
>>> key, subkey = random.split(key)
|
||||
>>> random.uniform(subkey)
|
||||
DeviceArray(0.10536897, dtype=float32)
|
||||
Array(0.10536897, dtype=float32)
|
||||
|
||||
Advanced
|
||||
--------
|
||||
|
Loading…
x
Reference in New Issue
Block a user