From a419e1917a6d8af7d9115df93daca3754fe35939 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Tue, 15 Nov 2022 11:51:55 -0800
Subject: [PATCH] Use jax.Array by default for doctests

PiperOrigin-RevId: 488719467
---
 .github/workflows/ci-build.yaml      |  1 +
 docs/type_promotion.rst              |  8 +--
 jax/_src/ad_checkpoint.py            |  2 +-
 jax/_src/api.py                      | 30 ++++-----
 jax/_src/custom_derivatives.py       |  6 +-
 jax/_src/errors.py                   | 28 ++++----
 jax/_src/lax/control_flow/loops.py   |  4 +-
 jax/_src/lax/lax.py                  | 18 ++---
 jax/_src/lax/parallel.py             |  4 +-
 jax/_src/lax/slicing.py              | 30 ++++-----
 jax/_src/nn/functions.py             | 10 +--
 jax/_src/nn/initializers.py          | 68 +++++++++----------
 jax/_src/numpy/index_tricks.py       | 98 ++++++++++++++--------------
 jax/_src/numpy/lax_numpy.py          | 20 +++---
 jax/_src/numpy/polynomial.py         |  4 +-
 jax/_src/ops/scatter.py              | 16 ++---
 jax/_src/sharding.py                 |  2 +-
 jax/_src/tree_util.py                |  6 +-
 jax/experimental/maps.py             |  4 +-
 jax/experimental/sparse/__init__.py  | 26 ++++----
 jax/experimental/sparse/bcoo.py      | 18 ++---
 jax/experimental/sparse/transform.py | 10 +--
 jax/random.py                        |  8 +--
 23 files changed, 211 insertions(+), 210 deletions(-)

diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml
index b8c9b7b5b..f8109e51a 100644
--- a/.github/workflows/ci-build.yaml
+++ b/.github/workflows/ci-build.yaml
@@ -155,6 +155,7 @@ jobs:
     - name: Test documentation
       env:
         XLA_FLAGS: "--xla_force_host_platform_device_count=8"
+        JAX_ARRAY: 1
       run: |
         pytest -n 1 --tb=short docs
         pytest -n 1 --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/gda_serialization --ignore=jax/collect_profile.py
diff --git a/docs/type_promotion.rst b/docs/type_promotion.rst
index 151436e57..3fbf6e88b 100644
--- a/docs/type_promotion.rst
+++ b/docs/type_promotion.rst
@@ -181,7 +181,7 @@ equivalent to that of Python scalars, such as the integer scalar ``2`` in the fo
 
   >>> x = jnp.arange(5, dtype='int8')
   >>> 2 * x
-  DeviceArray([0, 2, 4, 6, 8], dtype=int8)
+  Array([0, 2, 4, 6, 8], dtype=int8)
 
 JAX's weak type framework is designed to prevent unwanted type promotion within
 binary operations between JAX values and values with no explicitly user-specified type,
@@ -191,7 +191,7 @@ the expression above would lead to an implicit type promotion:
 .. code-block:: python
 
   >>> jnp.int32(2) * x
-  DeviceArray([0, 2, 4, 6, 8], dtype=int32)
+  Array([0, 2, 4, 6, 8], dtype=int32)
 
 When used in JAX, Python scalars are sometimes promoted to :class:`~jax.numpy.DeviceArray`
 objects, for example during JIT compilation. To maintain the desired promotion
@@ -201,7 +201,7 @@ that can be seen in an array's string representation:
 .. code-block:: python
 
   >>> jnp.asarray(2)
-  DeviceArray(2, dtype=int32, weak_type=True)
+  Array(2, dtype=int32, weak_type=True)
 
 If the ``dtype`` is specified explicitly, it will instead result in a standard
 strongly-typed array value:
@@ -209,4 +209,4 @@ strongly-typed array value:
 .. code-block:: python
 
   >>> jnp.asarray(2, dtype='int32')
-  DeviceArray(2, dtype=int32)
\ No newline at end of file
+  Array(2, dtype=int32)
\ No newline at end of file
diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index fcf28b693..51a57659b 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -189,7 +189,7 @@ def checkpoint(fun: Callable, *, prevent_cse: bool = True,
     ...   return z
     ...
     >>> jax.value_and_grad(g)(2.0)
-    (DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32, weak_type=True))
+    (Array(0.78907233, dtype=float32, weak_type=True), Array(-0.2556391, dtype=float32, weak_type=True))
 
   Here, the same value is produced whether or not the :func:`jax.checkpoint`
   decorator is present. When the decorator is not present, the values
diff --git a/jax/_src/api.py b/jax/_src/api.py
index 013cc9e7f..53a95dc4e 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -340,7 +340,7 @@ def jit(
     ...   return x
     >>>
     >>> g(jnp.arange(4), 3)
-    DeviceArray([ 0, 1, 256, 6561], dtype=int32)
+    Array([ 0, 1, 256, 6561], dtype=int32)
   """
   if abstracted_axes and not config.jax_dynamic_shapes:
     raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
@@ -1431,14 +1431,14 @@ def hessian(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
  >>> import jax.numpy as jnp
  >>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
  >>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
-  {'c': {'a': {'a': DeviceArray([[[ 2., 0.], [ 0., 0.]],
-                                 [[ 0., 0.], [ 0., 12.]]], dtype=float32),
-               'b': DeviceArray([[[ 1. , 0. ], [ 0. , 0. ]],
-                                 [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32)},
-        'b': {'a': DeviceArray([[[ 1. , 0. ], [ 0. , 0. ]],
-                                [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32),
-              'b': DeviceArray([[[0. , 0. ], [0. , 0. ]],
-                                [[0. , 0. ], [0. , 3.843624]]], dtype=float32)}}}
+  {'c': {'a': {'a': Array([[[ 2., 0.], [ 0., 0.]],
+                           [[ 0., 0.], [ 0., 12.]]], dtype=float32),
+               'b': Array([[[ 1. , 0. ], [ 0. , 0. ]],
+                           [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32)},
+        'b': {'a': Array([[[ 1. , 0. ], [ 0. , 0. ]],
+                          [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32),
+              'b': Array([[[0. , 0. ], [0. , 0. ]],
+                          [[0. , 0. ], [0. , 3.843624]]], dtype=float32)}}}
 
  Thus each leaf in the tree structure of ``jax.hessian(fun)(x)`` corresponds to
  a leaf of ``fun(x)`` and a pair of leaves of ``x``. For each leaf in
@@ -1623,13 +1623,13 @@ def vmap(fun: F,
  (to keep it unmapped).
 
  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
-  (DeviceArray([4., 5.], dtype=float32), 8.0)
+  (Array([4., 5.], dtype=float32), 8.0)
 
  If the ``out_axes`` is specified for an unmapped result, the result is
  broadcast across the mapped axis:
 
  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
-  (DeviceArray([4., 5.], dtype=float32), DeviceArray([8., 8.], dtype=float32, weak_type=True))
+  (Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))
 
  If the ``out_axes`` is specified for a mapped result, the result is
  transposed accordingly.
@@ -2499,7 +2499,7 @@ def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]:
  >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
  ...
  >>> jax.jvp(f, (2.,), (3.,))
-  (DeviceArray(3.26819, dtype=float32, weak_type=True), DeviceArray(-5.00753, dtype=float32, weak_type=True))
+  (Array(3.26819, dtype=float32, weak_type=True), Array(-5.00753, dtype=float32, weak_type=True))
  >>> y, f_jvp = jax.linearize(f, 2.)
  >>> print(y)
  3.2681944
@@ -2718,7 +2718,7 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
  >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
  >>> f_transpose = jax.linear_transpose(f, scalar, scalar)
  >>> f_transpose(1.0)
-  (DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32))
+  (Array(0.5, dtype=float32), Array(-0.5, dtype=float32))
  """
  reduce_axes = _ensure_str_tuple(reduce_axes)
  primals_flat, in_tree = tree_flatten(primals)
@@ -3046,7 +3046,7 @@ def device_get(x: Any):
  If ``x`` is a pytree, then the individual buffers are copied in parallel.
 
  Args:
-    x: An array, scalar, DeviceArray or (nested) standard Python container thereof
+    x: An array, scalar, Array or (nested) standard Python container thereof
      representing the array to be transferred to host.
 
  Returns:
@@ -3054,7 +3054,7 @@ def device_get(x: Any):
    value of ``x``.
 
  Examples:
-    Passing a DeviceArray:
+    Passing an Array:
 
    >>> import jax
    >>> x = jax.numpy.array([1., 2., 3.])
diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index 2c99ef99f..aafb93a1b 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -1066,13 +1066,13 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
    ...   return transposed
 
    >>> div_add(9., 3.)
-    DeviceArray(12., dtype=float32, weak_type=True)
+    Array(12., dtype=float32, weak_type=True)
 
    >>> transpose(partial(div_add, denom=3.), 1.)(18.)  # custom
-    DeviceArray(24., dtype=float32, weak_type=True)
+    Array(24., dtype=float32, weak_type=True)
 
    >>> transpose(lambda x: x + x / 3., 1.)(18.)  # reference
-    DeviceArray(24., dtype=float32, weak_type=True)
+    Array(24., dtype=float32, weak_type=True)
 
  The above definition of ``f`` illustrates the purpose of a residual
  argument: division is linear in one of its inputs (the dividend
diff --git a/jax/_src/errors.py b/jax/_src/errors.py
index 4faa33556..f62645e04 100644
--- a/jax/_src/errors.py
+++ b/jax/_src/errors.py
@@ -69,7 +69,7 @@ class ConcretizationTypeError(JAXTypeError):
      ...   return x.min(axis)
 
      >>> func(jnp.arange(4), 0)
-      DeviceArray(0, dtype=int32)
+      Array(0, dtype=int32)
 
  Traced value used in control flow
    Another case where this often arises is when a traced value is used in
@@ -94,7 +94,7 @@ class ConcretizationTypeError(JAXTypeError):
      ...   return jnp.where(x.sum() < y.sum(), x, y)
 
      >>> func(jnp.ones(4), jnp.zeros(4))
-      DeviceArray([0., 0., 0., 0.], dtype=float32)
+      Array([0., 0., 0., 0.], dtype=float32)
 
    For more complicated control flow including loops, see
    :ref:`lax-control-flow`.
@@ -140,7 +140,7 @@ class ConcretizationTypeError(JAXTypeError):
      ...   return jnp.where(x > 1, x, 0).sum()
 
      >>> func(jnp.arange(4))
-      DeviceArray(5, dtype=int32)
+      Array(5, dtype=int32)
 
    To understand more subtleties having to do with tracers vs. regular values,
    and concrete vs. abstract values, you may want to read
@@ -209,7 +209,7 @@ class NonConcreteBooleanIndexError(JAXIndexError):
      ...   return jnp.where(x > 0, x, 0).sum()
 
      >>> sum_of_positive(jnp.arange(-5, 5))
-      DeviceArray(10, dtype=int32)
+      Array(10, dtype=int32)
 
    This pattern of replacing boolean masking with three-argument
    :func:`~jax.numpy.where` is a common solution to this sort of problem.
@@ -236,7 +236,7 @@ class NonConcreteBooleanIndexError(JAXIndexError):
      ...   return jnp.where(x < 0, 0, x)
 
      >>> manual_clip(jnp.arange(-2, 2))
-      DeviceArray([0, 0, 0, 1], dtype=int32)
+      Array([0, 0, 0, 1], dtype=int32)
  """
  def __init__(self, tracer: core.Tracer):
    super().__init__(
@@ -275,7 +275,7 @@ class TracerArrayConversionError(JAXTypeError):
      ...   return jnp.sin(x)
 
      >>> func(jnp.arange(4))
-      DeviceArray([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
+      Array([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
 
  Indexing a numpy array with a tracer
    If this error arises on a line that involves array indexing, it may be that
@@ -302,7 +302,7 @@ class TracerArrayConversionError(JAXTypeError):
      ...   return jnp.asarray(x)[i]
 
      >>> func(0)
-      DeviceArray(0, dtype=int32)
+      Array(0, dtype=int32)
 
    or by declaring the index as a static argument::
 
@@ -311,7 +311,7 @@ class TracerArrayConversionError(JAXTypeError):
      ...   return x[i]
 
      >>> func(0)
-      DeviceArray(0, dtype=int32)
+      Array(0, dtype=int32)
 
  To understand more subtleties having to do with tracers vs. regular values,
  and concrete vs. abstract values, you may want to read
@@ -354,15 +354,15 @@ class TracerIntegerConversionError(JAXTypeError):
      ...   return np.split(x, 2, axis)
 
      >>> func(np.arange(10), 0)
-      [DeviceArray([0, 1, 2, 3, 4], dtype=int32),
-       DeviceArray([5, 6, 7, 8, 9], dtype=int32)]
+      [Array([0, 1, 2, 3, 4], dtype=int32),
+       Array([5, 6, 7, 8, 9], dtype=int32)]
 
    An alternative is to apply the transformation to a closure that
    encapsulates the arguments to be protected, either manually as below
    or by using :func:`functools.partial`::
 
      >>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
-      [DeviceArray([0, 1], dtype=int32), DeviceArray([2, 3], dtype=int32)]
+      [Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]
 
    **Note a new closure is created at every invocation, which defeats the
    compilation caching mechanism, which is why static_argnums is preferred.**
@@ -395,7 +395,7 @@ class TracerIntegerConversionError(JAXTypeError):
      ...   return jnp.array(L)[i]
 
      >>> func(0)
-      DeviceArray(1, dtype=int32)
+      Array(1, dtype=int32)
 
    or by declaring the index as a static argument::
 
@@ -404,7 +404,7 @@ class TracerIntegerConversionError(JAXTypeError):
      ...   return L[i]
 
      >>> func(0)
-      DeviceArray(1, dtype=int32, weak_type=True)
+      Array(1, dtype=int32, weak_type=True)
 
  To understand more subtleties having to do with tracers vs. regular values,
  and concrete vs. abstract values, you may want to read
@@ -502,7 +502,7 @@ class UnexpectedTracerError(JAXTypeError):
    >>> y = not_side_effecting(x)
    >>> outs.append(y)
    >>> outs[0] + 1  # all good! no longer a leaked value.
-    DeviceArray(3, dtype=int32, weak_type=True)
+    Array(3, dtype=int32, weak_type=True)
 
  Leak checker
    As discussed in point 2 and 3 above, JAX shows a reconstructed stack trace
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index 2c0ed9f9d..326ea663f 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -1782,7 +1782,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
  Example 1: partial sums of an array of numbers:
 
  >>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
-  DeviceArray([0, 1, 3, 6], dtype=int32)
+  Array([0, 1, 3, 6], dtype=int32)
 
  Example 2: partial products of an array of matrices
 
@@ -1794,7 +1794,7 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
  Example 3: reversed partial sums of an array of numbers
 
  >>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
-  DeviceArray([6, 6, 5, 3], dtype=int32)
+  Array([6, 6, 5, 3], dtype=int32)
 
  .. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.",
    Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index c46d7d66b..7f40e3b9d 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -244,9 +244,9 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array:
  values which appear as zero in any operations. Consider this example::
 
    >>> jnp.nextafter(0, 1)  # denormal numbers are representable
-    DeviceArray(1.e-45, dtype=float32, weak_type=True)
+    Array(1.e-45, dtype=float32, weak_type=True)
    >>> jnp.nextafter(0, 1) * 1  # but are flushed to zero
-    DeviceArray(0., dtype=float32, weak_type=True)
+    Array(0., dtype=float32, weak_type=True)
 
  For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``.
  """
@@ -853,18 +853,18 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
    >>> x = jnp.arange(6)
    >>> y = reshape(x, (2, 3))
    >>> y
-    DeviceArray([[0, 1, 2],
+    Array([[0, 1, 2],
           [3, 4, 5]], dtype=int32)
 
    Reshaping back to one dimension:
 
    >>> reshape(y, (6,))
-    DeviceArray([0, 1, 2, 3, 4, 5], dtype=int32)
+    Array([0, 1, 2, 3, 4, 5], dtype=int32)
 
    Reshaping to one dimension with permutation of dimensions:
 
    >>> reshape(y, (6,), (1, 0))
-    DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)
+    Array([0, 3, 1, 4, 2, 5], dtype=int32)
  """
  new_sizes = canonicalize_shape(new_sizes)  # TODO
  new_sizes = tuple(new_sizes)
@@ -1266,13 +1266,13 @@ def stop_gradient(x: T) -> T:
  For example:
 
  >>> jax.grad(lambda x: x**2)(3.)
-  DeviceArray(6., dtype=float32, weak_type=True)
+  Array(6., dtype=float32, weak_type=True)
  >>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
-  DeviceArray(0., dtype=float32, weak_type=True)
+  Array(0., dtype=float32, weak_type=True)
  >>> jax.grad(jax.grad(lambda x: x**2))(3.)
-  DeviceArray(2., dtype=float32, weak_type=True)
+  Array(2., dtype=float32, weak_type=True)
  >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
-  DeviceArray(0., dtype=float32, weak_type=True)
+  Array(0., dtype=float32, weak_type=True)
  """
  def stop(x):
    # only bind primitive on inexact dtypes, to avoid some staging
diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index 94bd35bf0..efacc52fa 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -386,9 +386,9 @@ def axis_index(axis_name):
  ...   return lax.axis_index('i')
  ...
  >>> f(np.zeros(4))
-  ShardedDeviceArray([0, 1, 2, 3], dtype=int32)
+  Array([0, 1, 2, 3], dtype=int32)
  >>> f(np.zeros(8))
-  ShardedDeviceArray([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
+  Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
  >>> @partial(jax.pmap, axis_name='i')
  ... @partial(jax.pmap, axis_name='j')
  ... def f(_):
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index 7dff20d60..31e3b1ef4 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -84,21 +84,21 @@ def dynamic_slice(operand: Array, start_indices: Union[Array, Sequence[ArrayLike
 
    >>> x = jnp.arange(12).reshape(3, 4)
    >>> x
-    DeviceArray([[ 0, 1, 2, 3],
-                 [ 4, 5, 6, 7],
-                 [ 8, 9, 10, 11]], dtype=int32)
+    Array([[ 0, 1, 2, 3],
+           [ 4, 5, 6, 7],
+           [ 8, 9, 10, 11]], dtype=int32)
 
    >>> dynamic_slice(x, (1, 1), (2, 3))
-    DeviceArray([[ 5, 6, 7],
-                 [ 9, 10, 11]], dtype=int32)
+    Array([[ 5, 6, 7],
+           [ 9, 10, 11]], dtype=int32)
 
    Note the potentially surprising behavior for the case where the requested
    slice overruns the bounds of the array; in this case the start index is
    adjusted to return a slice of the requested size:
 
    >>> dynamic_slice(x, (1, 1), (2, 4))
-    DeviceArray([[ 4, 5, 6, 7],
-                 [ 8, 9, 10, 11]], dtype=int32)
+    Array([[ 4, 5, 6, 7],
+           [ 8, 9, 10, 11]], dtype=int32)
  """
  start_indices = _dynamic_slice_indices(operand, start_indices)
  if jax.config.jax_dynamic_shapes:
@@ -129,25 +129,25 @@ def dynamic_update_slice(operand: Array, update: ArrayLike,
    >>> x = jnp.zeros(6)
    >>> y = jnp.ones(3)
    >>> dynamic_update_slice(x, y, (2,))
-    DeviceArray([0., 0., 1., 1., 1., 0.], dtype=float32)
+    Array([0., 0., 1., 1., 1., 0.], dtype=float32)
 
    If the update slice is too large to fit in the array, the start
    index will be adjusted to make it fit
 
    >>> dynamic_update_slice(x, y, (3,))
-    DeviceArray([0., 0., 0., 1., 1., 1.], dtype=float32)
+    Array([0., 0., 0., 1., 1., 1.], dtype=float32)
    >>> dynamic_update_slice(x, y, (5,))
-    DeviceArray([0., 0., 0., 1., 1., 1.], dtype=float32)
+    Array([0., 0., 0., 1., 1., 1.], dtype=float32)
 
    Here is an example of a two-dimensional slice update:
 
    >>> x = jnp.zeros((4, 4))
    >>> y = jnp.ones((2, 2))
    >>> dynamic_update_slice(x, y, (1, 2))
-    DeviceArray([[0., 0., 0., 0.],
-                 [0., 0., 1., 1.],
-                 [0., 0., 1., 1.],
-                 [0., 0., 0., 0.]], dtype=float32)
+    Array([[0., 0., 0., 0.],
+           [0., 0., 1., 1.],
+           [0., 0., 1., 1.],
+           [0., 0., 0., 0.]], dtype=float32)
  """
  start_indices = _dynamic_slice_indices(operand, start_indices)
  return dynamic_update_slice_p.bind(operand, update, *start_indices)
@@ -1563,7 +1563,7 @@ def _clamp_scatter_indices(operand, indices, updates, *, dnums):
  upper_bounds: core.Shape = tuple(operand.shape[i] - slice_sizes[i]
                                   for i in dnums.scatter_dims_to_operand_dims)
 
-  # Stack upper_bounds into a DeviceArray[n]
+  # Stack upper_bounds into an Array[n]
  upper_bound = lax.shape_as_value(upper_bounds)
  upper_bound = lax.min(upper_bound, np.iinfo(indices.dtype).max)
  upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py
index d0f232803..24a099f88 100644
--- a/jax/_src/nn/functions.py
+++ b/jax/_src/nn/functions.py
@@ -408,15 +408,15 @@ def one_hot(x: Array, num_classes: int, *,
  ``num_classes`` with the element at ``index`` set to one::
 
    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
-    DeviceArray([[1., 0., 0.],
-                 [0., 1., 0.],
-                 [0., 0., 1.]], dtype=float32)
+    Array([[1., 0., 0.],
+           [0., 1., 0.],
+           [0., 0., 1.]], dtype=float32)
 
  Indicies outside the range [0, num_classes) will be encoded as zeros::
 
    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
-    DeviceArray([[0., 0., 0.],
-                 [0., 0., 0.]], dtype=float32)
+    Array([[0., 0., 0.],
+           [0., 0., 0.]], dtype=float32)
 
  Args:
    x: A tensor of indices.
diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py
index c3330ac41..3c1cdb209 100644
--- a/jax/_src/nn/initializers.py
+++ b/jax/_src/nn/initializers.py
@@ -55,8 +55,8 @@ def zeros(key: KeyArray,
 
  >>> import jax, jax.numpy as jnp
  >>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
-  DeviceArray([[0., 0., 0.],
-               [0., 0., 0.]], dtype=float32)
+  Array([[0., 0., 0.],
+         [0., 0., 0.]], dtype=float32)
  """
  return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
 
@@ -69,9 +69,9 @@ def ones(key: KeyArray,
 
  >>> import jax, jax.numpy as jnp
  >>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
-  DeviceArray([[1., 1.],
-               [1., 1.],
-               [1., 1.]], dtype=float32)
+  Array([[1., 1.],
+         [1., 1.],
+         [1., 1.]], dtype=float32)
  """
  return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))
 
@@ -87,8 +87,8 @@ def constant(value: Array,
  >>> import jax, jax.numpy as jnp
  >>> initializer = jax.nn.initializers.constant(-7)
  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
-  DeviceArray([[-7., -7., -7.],
-               [-7., -7., -7.]], dtype=float32)
+  Array([[-7., -7., -7.],
+         [-7., -7., -7.]], dtype=float32)
  """
  def init(key: KeyArray,
           shape: core.Shape,
@@ -112,8 +112,8 @@ def uniform(scale: RealNumeric = 1e-2,
  >>> import jax, jax.numpy as jnp
  >>> initializer = jax.nn.initializers.uniform(10.0)
  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
-  DeviceArray([[7.298188 , 8.691938 , 8.7230015],
-               [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
+  Array([[7.298188 , 8.691938 , 8.7230015],
+         [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
  """
  def init(key: KeyArray,
           shape: core.Shape,
@@ -137,8 +137,8 @@ def normal(stddev: RealNumeric = 1e-2,
  >>> import jax, jax.numpy as jnp
  >>> initializer = jax.nn.initializers.normal(5.0)
  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
-  DeviceArray([[ 3.0613258 , 5.6129413 , 5.6866574 ],
-               [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
+  Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],
+         [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
  """
  def init(key: KeyArray,
           shape: core.Shape,
@@ -319,8 +319,8 @@ def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
  >>> import jax, jax.numpy as jnp
  >>> initializer = jax.nn.initializers.glorot_uniform()
  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
-  DeviceArray([[ 0.50350785, 0.8088631 , 0.81566876],
-               [-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
+  Array([[ 0.50350785, 0.8088631 , 0.81566876],
+         [-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
 
  .. _Glorot uniform initializer: http://proceedings.mlr.press/v9/glorot10a.html
  """
@@ -357,8 +357,8 @@ def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
  >>> import jax, jax.numpy as jnp
  >>> initializer = jax.nn.initializers.glorot_normal()
  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
-  DeviceArray([[ 0.41770416, 0.75262755, 0.7619329 ],
-               [-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
+  Array([[ 0.41770416, 0.75262755, 0.7619329 ],
+         [-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
 
  .. _Glorot normal initializer: http://proceedings.mlr.press/v9/glorot10a.html
  """
@@ -394,8 +394,8 @@ def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
  >>> import jax, jax.numpy as jnp
  >>> initializer = jax.nn.initializers.lecun_uniform()
  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
-  DeviceArray([[ 0.56293887, 0.90433645, 0.9119454 ],
-               [-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
+  Array([[ 0.56293887, 0.90433645, 0.9119454 ],
+         [-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
 
  .. _Lecun uniform initializer: https://arxiv.org/abs/1706.02515
  """
@@ -429,8 +429,8 @@ def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
  >>> import jax, jax.numpy as jnp
  >>> initializer = jax.nn.initializers.lecun_normal()
  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
-  DeviceArray([[ 0.46700746, 0.8414632 , 0.8518669 ],
-               [-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
+  Array([[ 0.46700746, 0.8414632 , 0.8518669 ],
+         [-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
 
  .. _Lecun normal initializer: https://arxiv.org/abs/1706.02515
  """
@@ -465,8 +465,8 @@ def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
  >>> import jax, jax.numpy as jnp
  >>> initializer = jax.nn.initializers.kaiming_uniform()
  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
-  DeviceArray([[ 0.79611576, 1.2789248 , 1.2896855 ],
-               [-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
+  Array([[ 0.79611576, 1.2789248 , 1.2896855 ],
+         [-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
 
  .. _He uniform initializer: https://arxiv.org/abs/1502.01852
  """
@@ -503,8 +503,8 @@ def he_normal(in_axis: Union[int, Sequence[int]] = -2,
  >>> import jax, jax.numpy as jnp
  >>> initializer = jax.nn.initializers.kaiming_normal()
  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
-  DeviceArray([[ 0.6604483 , 1.1900088 , 1.2047218 ],
-               [-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
+  Array([[ 0.6604483 , 1.1900088 , 1.2047218 ],
+         [-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
 
  .. _He normal initializer: https://arxiv.org/abs/1502.01852
  """
@@ -536,8 +536,8 @@ def orthogonal(scale: RealNumeric = 1.0,
  >>> import jax, jax.numpy as jnp
  >>> initializer = jax.nn.initializers.orthogonal()
  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  # doctest: +SKIP
-  DeviceArray([[ 3.9026976e-01,  7.2495741e-01, -5.6756169e-01],
-               [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
+  Array([[ 3.9026976e-01,  7.2495741e-01, -5.6756169e-01],
+         [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
  """
  def init(key: KeyArray,
           shape: core.Shape,
@@ -579,17 +579,17 @@ def delta_orthogonal(
  >>> import jax, jax.numpy as jnp
  >>> initializer = jax.nn.initializers.delta_orthogonal()
  >>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32)  # doctest: +SKIP
-  DeviceArray([[[ 0.        ,  0.        ,  0.        ],
-                [ 0.        ,  0.        ,  0.        ],
-                [ 0.        ,  0.        ,  0.        ]],
+  Array([[[ 0.        ,  0.        ,  0.        ],
+          [ 0.        ,  0.        ,  0.        ],
+          [ 0.        ,  0.        ,  0.        ]],
 
-               [[ 0.27858758, -0.7949833 , -0.53887904],
-                [ 0.9120717 ,  0.04322892,  0.40774566],
-                [-0.30085585, -0.6050892 ,  0.73712474]],
+         [[ 0.27858758, -0.7949833 , -0.53887904],
+          [ 0.9120717 ,  0.04322892,  0.40774566],
+          [-0.30085585, -0.6050892 ,  0.73712474]],
 
-               [[ 0.        ,  0.        ,  0.        ],
-                [ 0.        ,  0.        ,  0.        ],
-                [ 0.        ,  0.        ,  0.        ]]], dtype=float32)
+         [[ 0.        ,  0.        ,  0.        ],
+          [ 0.        ,  0.        ,  0.        ],
+          [ 0.        ,  0.        ,  0.        ]]], dtype=float32)
 
  .. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393
 
diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py
index c0affbb83..5fff1df13 100644
--- a/jax/_src/numpy/index_tricks.py
+++ b/jax/_src/numpy/index_tricks.py
@@ -70,20 +70,20 @@ class _Mgrid(_IndexGrid):
  Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:
 
  >>> jnp.mgrid[0:4:1]
-  DeviceArray([0, 1, 2, 3], dtype=int32)
+  Array([0, 1, 2, 3], dtype=int32)
 
  Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:
 
  >>> jnp.mgrid[0:1:4j]
-  DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
+  Array([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
 
  Multiple slices can be used to create broadcasted grids of indices:
 
  >>> jnp.mgrid[:2, :3]
-  DeviceArray([[[0, 0, 0],
-                [1, 1, 1]],
-               [[0, 1, 2],
-                [0, 1, 2]]], dtype=int32)
+  Array([[[0, 0, 0],
+          [1, 1, 1]],
+         [[0, 1, 2],
+          [0, 1, 2]]], dtype=int32)
  """
  sparse = False
  op_name = "mgrid"
@@ -105,19 +105,19 @@ class _Ogrid(_IndexGrid):
  Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:
 
  >>> jnp.ogrid[0:4:1]
-  DeviceArray([0, 1, 2, 3], dtype=int32)
+  Array([0, 1, 2, 3], dtype=int32)
 
  Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:
 
  >>> jnp.ogrid[0:1:4j]
-  DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
+  Array([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
 
  Multiple slices can be used to create sparse grids of indices:
 
  >>> jnp.ogrid[:2, :3]
-  [DeviceArray([[0],
-                [1]], dtype=int32),
-   DeviceArray([[0, 1, 2]], dtype=int32)]
+  [Array([[0],
+          [1]], dtype=int32),
+   Array([[0, 1, 2]], dtype=int32)]
  """
  sparse = True
  op_name = "ogrid"
@@ -213,56 +213,56 @@ class RClass(_AxisConcat):
  Passing slices in the form ``[start:stop:step]`` generates ``jnp.arange`` objects:
 
  >>> jnp.r_[-1:5:1, 0, 0, jnp.array([1,2,3])]
-  DeviceArray([-1, 0, 1, 2, 3, 4, 0, 0, 1, 2, 3], dtype=int32)
+  Array([-1, 0, 1, 2, 3, 4, 0, 0, 1, 2, 3], dtype=int32)
 
  An imaginary value for ``step`` will create a ``jnp.linspace`` object instead,
  which includes the right endpoint:
 
  >>> jnp.r_[-1:1:6j, 0, jnp.array([1,2,3])]
-  DeviceArray([-1.        , -0.6       , -0.20000002,  0.20000005,
-                0.6       ,  1.        ,  0.        ,  1.        ,
-                2.        ,  3.        ], dtype=float32)
+  Array([-1.        , -0.6       , -0.20000002,  0.20000005,
+          0.6       ,  1.        ,  0.        ,  1.        ,
+          2.        ,  3.        ], dtype=float32)
 
  Use a string directive of the form ``"axis,dims,trans1d"`` as the first argument to
  specify concatenation axis, minimum number of dimensions, and the position of the
  upgraded array's original dimensions in the resulting array's shape tuple:
 
  >>> jnp.r_['0,2', [1,2,3], [4,5,6]]  # concatenate along first axis, 2D output
-  DeviceArray([[1, 2, 3],
-               [4, 5, 6]], dtype=int32)
+  Array([[1, 2, 3],
+         [4, 5, 6]], dtype=int32)
 
  >>> jnp.r_['0,2,0', [1,2,3], [4,5,6]]  # push last input axis to the front
-  DeviceArray([[1],
-               [2],
-               [3],
-               [4],
-               [5],
-               [6]], dtype=int32)
+  Array([[1],
+         [2],
+         [3],
+         [4],
+         [5],
+         [6]], dtype=int32)
 
  Negative values for ``trans1d`` offset the last axis towards the start
  of the shape tuple:
 
  >>> jnp.r_['0,2,-2', [1,2,3], [4,5,6]]
-  DeviceArray([[1],
-               [2],
-               [3],
-               [4],
-               [5],
-               [6]], dtype=int32)
+  Array([[1],
+         [2],
+         [3],
+         [4],
+         [5],
+         [6]], dtype=int32)
 
  Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs
  to create an array with an extra row or column axis, respectively:
 
  >>> jnp.r_['r',[1,2,3], [4,5,6]]
-  DeviceArray([[1, 2, 3, 4, 5, 6]], dtype=int32)
+  Array([[1, 2, 3, 4, 5, 6]], dtype=int32)
 
  >>> jnp.r_['c',[1,2,3], [4,5,6]]
-  DeviceArray([[1],
-               [2],
-               [3],
-               [4],
-               [5],
-               [6]], dtype=int32)
+  Array([[1],
+         [2],
+         [3],
+         [4],
+         [5],
+         [6]], dtype=int32)
 
  For higher-dimensional inputs (``dim >= 2``), both directives ``"r"`` and ``"c"``
  give the same result.
@@ -288,32 +288,32 @@ class CClass(_AxisConcat):
 
  >>> a = jnp.arange(6).reshape((2,3))
  >>> jnp.c_[a,a]
-  DeviceArray([[0, 1, 2, 0, 1, 2],
-               [3, 4, 5, 3, 4, 5]], dtype=int32)
+  Array([[0, 1, 2, 0, 1, 2],
+         [3, 4, 5, 3, 4, 5]], dtype=int32)
 
  Use a string directive of the form ``"axis:dims:trans1d"`` as the first argument to
  specify concatenation axis, minimum number of dimensions, and the position of the
  upgraded array's original dimensions in the resulting array's shape tuple:
 
  >>> jnp.c_['0,2', [1,2,3], [4,5,6]]
-  DeviceArray([[1],
-               [2],
-               [3],
-               [4],
-               [5],
-               [6]], dtype=int32)
+  Array([[1],
+         [2],
+         [3],
+         [4],
+         [5],
+         [6]], dtype=int32)
 
  >>> jnp.c_['0,2,-1', [1,2,3], [4,5,6]]
-  DeviceArray([[1, 2, 3],
-               [4, 5, 6]], dtype=int32)
+  Array([[1, 2, 3],
+         [4, 5, 6]], dtype=int32)
 
  Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs
  to create an array with inputs stacked along the last axis:
 
  >>> jnp.c_['r',[1,2,3], [4,5,6]]
-  DeviceArray([[1, 4],
-               [2, 5],
-               [3, 6]], dtype=int32)
+  Array([[1, 4],
+         [2, 5],
+         [3, 6]], dtype=int32)
  """
  axis = -1
  ndmin = 2
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 52798174f..1c3a5ad29 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -277,12 +277,12 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array:
 
  >>> val = jnp.uint32(0xFFFFFFFF)
  >>> val.astype('int32')
-  DeviceArray(-1, dtype=int32)
+  Array(-1, dtype=int32)
 
  This function clips to the values representable in the new type:
 
  >>> _convert_and_clip_integer(val, 'int32')
-  DeviceArray(2147483647, dtype=int32)
+  Array(2147483647, dtype=int32)
  """
  val = val if isinstance(val, ndarray) else asarray(val)
  dtype = dtypes.canonicalize_dtype(dtype)
@@ -5158,21 +5158,21 @@ class _IndexUpdateHelper:
  --------
  >>> x = jnp.arange(5.0)
  >>> x
-  DeviceArray([0., 1., 2., 3., 4.], dtype=float32)
+  Array([0., 1., 2., 3., 4.], dtype=float32)
  >>> x.at[2].add(10)
-  DeviceArray([ 0., 1., 12., 3., 4.], dtype=float32)
+  Array([ 0., 1., 12., 3., 4.], dtype=float32)
  >>> x.at[10].add(10)  # out-of-bounds indices are ignored
-  DeviceArray([0., 1., 2., 3., 4.], dtype=float32)
+  Array([0., 1., 2., 3., 4.], dtype=float32)
  >>> x.at[20].add(10, mode='clip')
-  DeviceArray([ 0., 1., 2., 3., 14.], dtype=float32)
+  Array([ 0., 1., 2., 3., 14.], dtype=float32)
  >>> x.at[2].get()
-  DeviceArray(2., dtype=float32)
+  Array(2., dtype=float32)
  >>> x.at[20].get()  # out-of-bounds indices clipped
-  DeviceArray(4., dtype=float32)
+  Array(4., dtype=float32)
  >>> x.at[20].get(mode='fill')  # out-of-bounds indices filled with NaN
-  DeviceArray(nan, dtype=float32)
+  Array(nan, dtype=float32)
  >>> x.at[20].get(mode='fill', fill_value=-1)  # custom fill value
-  DeviceArray(-1., dtype=float32)
+  Array(-1., dtype=float32)
  """
 
  __slots__ = ("array",)
diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py
index 56be67932..23fa110c8 100644
--- a/jax/_src/numpy/polynomial.py
+++ b/jax/_src/numpy/polynomial.py
@@ -65,11 +65,11 @@ roots will be padded with NaN values:
 
 # The default behavior matches numpy and strips leading zeros:
 >>> jnp.roots(coeffs)
-DeviceArray([-2.+0.j], dtype=complex64)
+Array([-2.+0.j], dtype=complex64)
 
 # With strip_zeros=False, extra roots are set to NaN:
 >>> jnp.roots(coeffs, strip_zeros=False)
-DeviceArray([-2. +0.j, nan+nanj], dtype=complex64)
+Array([-2. +0.j, nan+nanj], dtype=complex64)
 """,
 extra_params="""
 strip_zeros : bool, default=True
diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py
index d7ab61175..fed929a12 100644
--- a/jax/_src/ops/scatter.py
+++ b/jax/_src/ops/scatter.py
@@ -229,13 +229,13 @@ def segment_sum(data: Array,
  >>> data = jnp.arange(5)
  >>> segment_ids = jnp.array([0, 0, 1, 1, 2])
  >>> segment_sum(data, segment_ids)
-  DeviceArray([1, 5, 4], dtype=int32)
+  Array([1, 5, 4], dtype=int32)
 
  Using JIT requires static `num_segments`:
 
  >>> from jax import jit
  >>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
-  DeviceArray([1, 5, 4], dtype=int32)
+  Array([1, 5, 4], dtype=int32)
  """
  return _segment_update(
      "segment_sum", data, segment_ids, lax.scatter_add, num_segments,
@@ -285,13 +285,13 @@ def segment_prod(data: Array,
  >>> data = jnp.arange(6)
  >>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
  >>> segment_prod(data, segment_ids)
-  DeviceArray([ 0, 6, 20], dtype=int32)
+  Array([ 0, 6, 20], dtype=int32)
 
  Using JIT requires static `num_segments`:
 
  >>> from jax import jit
  >>> jit(segment_prod, static_argnums=2)(data, segment_ids, 3)
-  DeviceArray([ 0, 6, 20], dtype=int32)
+  Array([ 0, 6, 20], dtype=int32)
  """
  return _segment_update(
      "segment_prod", data, segment_ids, lax.scatter_mul, num_segments,
@@ -340,13 +340,13 @@ def segment_max(data: Array,
  >>> data = jnp.arange(6)
  >>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
  >>> segment_max(data, segment_ids)
-  DeviceArray([1, 3, 5], dtype=int32)
+  Array([1, 3, 5], dtype=int32)
 
  Using JIT requires static `num_segments`:
 
  >>> from jax import jit
  >>> jit(segment_max, static_argnums=2)(data, segment_ids, 3)
-  DeviceArray([1, 3, 5], dtype=int32)
+  Array([1, 3, 5], dtype=int32)
  """
  return _segment_update(
      "segment_max", data, segment_ids, lax.scatter_max, num_segments,
@@ -395,13 +395,13 @@ def segment_min(data: Array,
  >>> data = jnp.arange(6)
  >>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
  >>> segment_min(data, segment_ids)
-  DeviceArray([0, 2, 4], dtype=int32)
+  Array([0, 2, 4], dtype=int32)
 
  Using JIT requires static `num_segments`:
 
  >>> from jax import jit
  >>> jit(segment_min, static_argnums=2)(data, segment_ids, 3)
-  DeviceArray([0, 2, 4], dtype=int32)
+  Array([0, 2, 4], dtype=int32)
  """
  return _segment_update(
      "segment_min", data, segment_ids, lax.scatter_min, num_segments,
diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py
index f04cf76cb..8e2662322 100644
--- a/jax/_src/sharding.py
+++ b/jax/_src/sharding.py
@@ -201,7 +201,7 @@ def _enable_cpp_named_sharding():
 
 @pxla.use_cpp_class(_enable_cpp_named_sharding())
 class NamedSharding(XLACompatibleSharding):
-  """NamedSharding is a way to express ``Sharding``s using named axes.
+  r"""NamedSharding is a way to express ``Sharding``\s using named axes.
 
   ``Mesh`` and ``PartitionSpec`` can be used to express a ``Sharding`` with
   a name.
diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py
index cc550faf5..d685c83fb 100644
--- a/jax/_src/tree_util.py
+++ b/jax/_src/tree_util.py
@@ -324,7 +324,7 @@ class Partial(functools.partial):
  >>> import jax.numpy as jnp
  >>> add_one = Partial(jnp.add, 1)
  >>> add_one(2)
-  DeviceArray(3, dtype=int32, weak_type=True)
+  Array(3, dtype=int32, weak_type=True)
 
  Pytree compatibility means that the resulting partial function can be passed
  as an argument within transformed JAX functions, which is not possible with a
@@ -336,13 +336,13 @@ class Partial(functools.partial):
  ...   return f(*args)
  ...
  >>> call_func(add_one, 2)
-  DeviceArray(3, dtype=int32, weak_type=True)
+  Array(3, dtype=int32, weak_type=True)
 
  Passing zero arguments to ``Partial`` effectively wraps the original function,
  making it a valid argument in JAX transformed functions:
 
  >>> call_func(Partial(jnp.add), 1, 2)
-  DeviceArray(3, dtype=int32, weak_type=True)
+  Array(3, dtype=int32, weak_type=True)
 
  Had we passed ``jnp.add`` to ``call_func`` directly, it would have resulted in a
  ``TypeError``.
diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py
index af310ba8a..5204462ff 100644
--- a/jax/experimental/maps.py
+++ b/jax/experimental/maps.py
@@ -415,8 +415,8 @@ def xmap(fun: Callable,
    >>> xmap(jnp.vdot,
    ...      in_axes=({0: 'left'}, {1: 'right'}),
    ...      out_axes=['left', 'right', ...])(x, x.T)
-    DeviceArray([[ 30, 80],
-                 [ 80, 255]], dtype=int32)
+    Array([[ 30, 80],
+           [ 80, 255]], dtype=int32)
 
  Note that the contraction in the program is performed over the positional axes,
  while named axes are just a convenient way to achieve batching. While this
diff --git a/jax/experimental/sparse/__init__.py b/jax/experimental/sparse/__init__.py
index 0e40ae837..bee4cb5af 100644
--- a/jax/experimental/sparse/__init__.py
+++ b/jax/experimental/sparse/__init__.py
@@ -45,21 +45,21 @@ Here is an example of creating a sparse array from a dense array:
 Convert back to a dense array with the ``todense()`` method:
 
 >>> M_sp.todense()
-DeviceArray([[0., 1., 0., 2.],
-             [3., 0., 0., 0.],
-             [0., 0., 4., 0.]], dtype=float32)
+Array([[0., 1., 0., 2.],
+       [3., 0., 0., 0.],
+       [0., 0., 4., 0.]], dtype=float32)
 
 The BCOO format is a somewhat modified version of the standard COO format, and
 the dense representation can be seen in the ``data`` and ``indices`` attributes:
 
 >>> M_sp.data  # Explicitly stored data
-DeviceArray([1., 2., 3., 4.], dtype=float32)
+Array([1., 2., 3., 4.], dtype=float32)
 
 >>> M_sp.indices  # Indices of the stored data
-DeviceArray([[0, 1],
-             [0, 3],
-             [1, 0],
-             [2, 2]], dtype=int32)
+Array([[0, 1],
+       [0, 3],
+       [1, 0],
+       [2, 2]], dtype=int32)
 
 BCOO objects have familiar array-like attributes, as well as sparse-specific
 attributes:
@@ -82,10 +82,10 @@ product:
 
 >>> y = jnp.array([3., 6., 5.])
 >>> M_sp.T @ y
-DeviceArray([18., 3., 20., 6.], dtype=float32)
+Array([18., 3., 20., 6.], dtype=float32)
 
 >>> M.T @ y  # Compare to dense version
-DeviceArray([18., 3., 20., 6.], dtype=float32)
+Array([18., 3., 20., 6.], dtype=float32)
 
 BCOO objects are designed to be compatible with JAX transforms, including
 :func:`jax.jit`, :func:`jax.vmap`, :func:`jax.grad`, and others. For example:
@@ -96,7 +96,7 @@ BCOO objects are designed to be compatible with JAX transforms, including :func:
 ...   return (M_sp.T @ y).sum()
 ...
 >>> jit(grad(f))(y)
-DeviceArray([3., 3., 4.], dtype=float32)
+Array([3., 3., 4.], dtype=float32)
 
 Note, however, that under normal circumstances :mod:`jax.numpy` and :mod:`jax.lax`
 functions do not know how to handle sparse matrices, so attempting to compute
@@ -114,7 +114,7 @@ Consider this function, which computes a more complicated result from a matrix a
 ...   return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
 ...
 >>> f(M, y)
-DeviceArray([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
+Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
 
 Were we to pass a sparse matrix to this directly, it would result in an error, because
 ``jnp`` functions do not recognize sparse inputs. However, with :func:`sparsify`, we get a version of
 this function that does accept sparse matrices:
 
 >>> f_sp = sparse.sparsify(f)
 >>> f_sp(M_sp, y)
-DeviceArray([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
+Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
 
 Currently support for :func:`sparsify` is limited to a couple dozen primitives,
 including:
diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py
index 5c0b01b0b..9e3df2b27 100644
--- a/jax/experimental/sparse/bcoo.py
+++ b/jax/experimental/sparse/bcoo.py
@@ -2331,17 +2331,17 @@ class BCOO(JAXSparse):
  Examine the internal representation:
 
  >>> M_sp.data
-  DeviceArray([2., 1., 4.], dtype=float32)
+  Array([2., 1., 4.], dtype=float32)
 
  >>> M_sp.indices
-  DeviceArray([[0, 1],
-               [1, 0],
-               [1, 2]], dtype=int32)
+  Array([[0, 1],
+         [1, 0],
+         [1, 2]], dtype=int32)
 
  Create a dense array from a sparse array:
 
  >>> M_sp.todense()
-  DeviceArray([[0., 2., 0.],
-               [1., 0., 4.]], dtype=float32)
+  Array([[0., 2., 0.],
+         [1., 0., 4.]], dtype=float32)
 
  Create a sparse array from COO data & indices:
@@ -2353,9 +2353,9 @@ class BCOO(JAXSparse):
  >>> mat
  BCOO(float32[3, 3], nse=3)
  >>> mat.todense()
-  DeviceArray([[1., 0., 0.],
-               [0., 3., 0.],
-               [0., 0., 5.]], dtype=float32)
+  Array([[1., 0., 0.],
+         [0., 3., 0.],
+         [0., 0., 5.]], dtype=float32)
  """
 
  # Note: additional BCOO methods are defined in transform.py
diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py
index 33030d731..071e2669a 100644
--- a/jax/experimental/sparse/transform.py
+++ b/jax/experimental/sparse/transform.py
@@ -34,16 +34,16 @@ For example:
 ...   return -(jnp.sin(mat) @ vec)
 ...
 >>> f(mat, vec)
-DeviceArray([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
-             -0.15574613], dtype=float32)
+Array([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
+       -0.15574613], dtype=float32)
 
 >>> mat_sparse = BCOO.fromdense(mat)
 >>> mat_sparse
 BCOO(float32[5, 5], nse=8)
 
 >>> sparsify(f)(mat_sparse, vec)
-DeviceArray([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
-             -0.15574613], dtype=float32)
+Array([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
+       -0.15574613], dtype=float32)
 """
 
 import functools
@@ -446,7 +446,7 @@ def sparsify(f, use_tracer=False):
  >>> v = jnp.array([3, 4, 2])
 
  >>> f(M, v)
-  DeviceArray([ 64, 82, 100, 118], dtype=int32)
+  Array([ 64, 82, 100, 118], dtype=int32)
  """
  if use_tracer:
    return _sparsify_with_tracer(f)
diff --git a/jax/random.py b/jax/random.py
index 6a12b04c0..90a96db9b 100644
--- a/jax/random.py
+++ b/jax/random.py
@@ -39,23 +39,23 @@ usually generated by the :py:func:`jax.random.PRNGKey` function::
 
 >>> from jax import random
 >>> key = random.PRNGKey(0)
 >>> key
-DeviceArray([0, 0], dtype=uint32)
+Array([0, 0], dtype=uint32)
 
 This key can then be used in any of JAX's random number generation routines::
 
 >>> random.uniform(key)
-DeviceArray(0.41845703, dtype=float32)
+Array(0.41845703, dtype=float32)
 
 Note that using a key does not modify it, so reusing the same key will lead to the same result::
 
 >>> random.uniform(key)
-DeviceArray(0.41845703, dtype=float32)
+Array(0.41845703, dtype=float32)
 
 If you need a new random number, you can use :meth:`jax.random.split` to generate new subkeys::
 
 >>> key, subkey = random.split(key)
 >>> random.uniform(subkey)
-DeviceArray(0.10536897, dtype=float32)
+Array(0.10536897, dtype=float32)
 
 Advanced
 --------
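
Reviewer note: to reproduce the new ``Array`` reprs locally, the same switch that
the CI change above turns on must be enabled in-process. A minimal sketch in the
document's own doctest style; the ``jax_array`` config flag is assumed here to be
the in-process counterpart of the ``JAX_ARRAY=1`` environment variable, following
JAX's usual pairing of ``JAX_FOO`` env vars with ``jax_foo`` config flags::

  >>> import jax
  >>> import jax.numpy as jnp
  >>> jax.config.update("jax_array", True)  # assumed equivalent of JAX_ARRAY=1 in CI
  >>> jnp.arange(4)  # doctest: +SKIP
  Array([0, 1, 2, 3], dtype=int32)

Equivalently, run the documentation tests the way CI now does, e.g.
``JAX_ARRAY=1 pytest -n 1 --tb=short docs``.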