Replace uses of jnp.ndarray with jax.Array inside JAX.

PiperOrigin-RevId: 509939691
Peter Hawkins 2023-02-15 14:52:31 -08:00 committed by jax authors
parent 7aa7e158f8
commit cd0533cab0
28 changed files with 174 additions and 158 deletions
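Editor's note: the diff below is a pure annotation cleanup. Since the jax.Array unification, jnp.ndarray is an alias of jax.Array, so swapping the names changes no behavior; the commit just standardizes internal code on the public type. The accompanying convention is to annotate values that JAX produces as jax.Array and, where a function also accepts NumPy arrays or Python scalars, to annotate inputs as jax.typing.ArrayLike. A minimal sketch of that convention (the function scaled_sum is hypothetical and not part of this commit):

import jax
import jax.numpy as jnp

def scaled_sum(x: jax.typing.ArrayLike, scale: float = 1.0) -> jax.Array:
  x_arr = jnp.asarray(x)           # inputs may be lists, NumPy arrays, scalars, ...
  return scale * jnp.sum(x_arr)    # ...but the result is always a jax.Array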

View File

@ -43,6 +43,7 @@ del _core
# See PEP 484 & https://github.com/google/jax/issues/7570
from jax._src.basearray import Array as Array
from jax import typing as typing
from jax._src.config import (
config as config,

View File

@ -1137,7 +1137,7 @@ def _check_error(error, *, debug=False):
def is_scalar_pred(pred) -> bool:
return (isinstance(pred, bool) or
isinstance(pred, jnp.ndarray) and pred.shape == () and
isinstance(pred, jax.Array) and pred.shape == () and
pred.dtype == jnp.dtype('bool'))

View File

@ -20,7 +20,7 @@ from typing import Any, Dict, Hashable, List, Optional, Protocol, Tuple
import numpy as np
import jax.numpy as jnp
import jax
from jax import tree_util
from jax._src import core
from jax._src import debugging
@ -84,7 +84,7 @@ class DebuggerFrame:
flat_globals, globals_tree = _safe_flatten_dict(self.globals)
flat_vars = flat_locals + flat_globals
is_valid = [
isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray))
isinstance(l, (core.Tracer, jax.Array, np.ndarray))
for l in flat_vars
]
invalid_vars, valid_vars = util.partition_list(is_valid, flat_vars)

View File

@ -257,10 +257,10 @@ class _Subproblem(NamedTuple):
in the workspace.
"""
# The row offset of the block in the matrix of blocks.
offset: jnp.ndarray
offset: jax.Array
# The size of the block.
size: jnp.ndarray
size: jax.Array
@partial(jax.jit, static_argnames=('termination_size',))
def _eigh_work(H, n, termination_size=256):

View File

@ -14,6 +14,7 @@
from typing import Any, Optional, Sequence, Tuple, Union, cast as type_cast
import jax
from jax._src.numpy import lax_numpy as jnp
from jax._src.util import prod
from jax._src.lax import lax
@ -22,7 +23,7 @@ from jax._src.lax import convolution
DType = Any
def conv_general_dilated_patches(
lhs: lax.Array,
lhs: jax.typing.ArrayLike,
filter_shape: Sequence[int],
window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
@ -31,7 +32,7 @@ def conv_general_dilated_patches(
dimension_numbers: Optional[convolution.ConvGeneralDilatedDimensionNumbers] = None,
precision: Optional[lax.PrecisionType] = None,
preferred_element_type: Optional[DType] = None,
) -> lax.Array:
) -> jax.Array:
"""Extract patches subject to the receptive field of `conv_general_dilated`.
Runs the input through a convolution with given parameters. The kernel of the
@ -84,24 +85,25 @@ def conv_general_dilated_patches(
(`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).
"""
lhs_array = jnp.asarray(lhs)
filter_shape = tuple(filter_shape)
dimension_numbers = convolution.conv_dimension_numbers(
lhs.shape, (1, 1) + filter_shape, dimension_numbers)
lhs_array.shape, (1, 1) + filter_shape, dimension_numbers)
lhs_spec, rhs_spec, out_spec = dimension_numbers
spatial_size = prod(filter_shape)
n_channels = lhs.shape[lhs_spec[1]]
n_channels = lhs_array.shape[lhs_spec[1]]
# Move separate `lhs` spatial locations into separate `rhs` channels.
rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)
rhs = jnp.eye(spatial_size, dtype=lhs_array.dtype).reshape(filter_shape * 2)
rhs = rhs.reshape((spatial_size, 1) + filter_shape)
rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1))
rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))
out = convolution.conv_general_dilated(
lhs=lhs,
lhs=lhs_array,
rhs=rhs,
window_strides=window_strides,
padding=padding,
@ -117,8 +119,8 @@ def conv_general_dilated_patches(
def conv_general_dilated_local(
lhs: jnp.ndarray,
rhs: jnp.ndarray,
lhs: jax.typing.ArrayLike,
rhs: jax.typing.ArrayLike,
window_strides: Sequence[int],
padding: Union[str, Sequence[Tuple[int, int]]],
filter_shape: Sequence[int],
@ -126,7 +128,7 @@ def conv_general_dilated_local(
rhs_dilation: Optional[Sequence[int]] = None,
dimension_numbers: Optional[convolution.ConvGeneralDilatedDimensionNumbers] = None,
precision: lax.PrecisionLike = None
) -> jnp.ndarray:
) -> jax.Array:
"""General n-dimensional unshared convolution operator with optional dilation.
Also known as locally connected layer, the operation is equivalent to
@ -195,6 +197,8 @@ def conv_general_dilated_local(
If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
(for a 2D convolution).
"""
lhs_array = jnp.asarray(lhs)
c_precision = lax.canonicalize_precision(precision)
lhs_precision = type_cast(
Optional[lax.PrecisionType],
@ -203,7 +207,7 @@ def conv_general_dilated_local(
else c_precision))
patches = conv_general_dilated_patches(
lhs=lhs,
lhs=lhs_array,
filter_shape=filter_shape,
window_strides=window_strides,
padding=padding,
@ -214,7 +218,7 @@ def conv_general_dilated_local(
)
lhs_spec, rhs_spec, out_spec = convolution.conv_dimension_numbers(
lhs.shape, (1, 1) + tuple(filter_shape), dimension_numbers)
lhs_array.shape, (1, 1) + tuple(filter_shape), dimension_numbers)
lhs_c_dims, rhs_c_dims = [out_spec[1]], [rhs_spec[1]]
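Editor's note: the conv_general_dilated_patches / conv_general_dilated_local hunks above show the pattern that goes with an ArrayLike annotation: convert once with jnp.asarray at the top of the function and read .shape/.dtype only from the converted jax.Array, since the raw argument may be a Python scalar or list without those attributes. A hedged sketch of the same pattern with a hypothetical local_mean helper (not part of the diff):

import jax
import jax.numpy as jnp

def local_mean(lhs: jax.typing.ArrayLike, window: int) -> jax.Array:
  lhs_array = jnp.asarray(lhs)          # ArrayLike in, jax.Array from here on
  if lhs_array.ndim != 1 or window > lhs_array.shape[0]:
    raise ValueError(f'expected a 1-D input at least as long as window={window}')
  kernel = jnp.ones((window,)) / window
  return jnp.convolve(lhs_array, kernel, mode='valid')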

View File

@ -94,7 +94,7 @@ class PRNGImpl(NamedTuple):
# -- PRNG key arrays
def _check_prng_key_data(impl, key_data: jnp.ndarray):
def _check_prng_key_data(impl, key_data: jax.Array):
ndim = len(impl.key_shape)
if not all(hasattr(key_data, attr) for attr in ['ndim', 'shape', 'dtype']):
raise TypeError("JAX encountered invalid PRNG key data: expected key_data "
@ -136,7 +136,7 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
"""
impl: PRNGImpl
_base_array: jnp.ndarray
_base_array: jax.Array
def __init__(self, impl, key_data: Any):
assert not isinstance(key_data, core.Tracer)
@ -794,14 +794,14 @@ mlir.register_lowering(random_unwrap_p, random_unwrap_lowering)
# -- threefry2x32 PRNG implementation
def _is_threefry_prng_key(key: jnp.ndarray) -> bool:
def _is_threefry_prng_key(key: jax.Array) -> bool:
try:
return key.shape == (2,) and key.dtype == np.uint32
except AttributeError:
return False
def threefry_seed(seed: jnp.ndarray) -> jnp.ndarray:
def threefry_seed(seed: jax.Array) -> jax.Array:
"""Create a single raw threefry PRNG key from an integer seed.
Args:
@ -1099,26 +1099,26 @@ def threefry_2x32(keypair, count):
return lax.reshape(out[:-1] if odd_size else out, count.shape)
def threefry_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
def threefry_split(key: jax.Array, num: int) -> jax.Array:
if config.jax_threefry_partitionable:
return _threefry_split_foldlike(key, int(num)) # type: ignore
else:
return _threefry_split_original(key, int(num)) # type: ignore
@partial(jit, static_argnums=(1,), inline=True)
def _threefry_split_original(key, num) -> jnp.ndarray:
def _threefry_split_original(key, num) -> jax.Array:
counts = lax.iota(np.uint32, num * 2)
return lax.reshape(threefry_2x32(key, counts), (num, 2))
@partial(jit, static_argnums=(1,), inline=True)
def _threefry_split_foldlike(key, num) -> jnp.ndarray:
def _threefry_split_foldlike(key, num) -> jax.Array:
k1, k2 = key
counts1, counts2 = iota_2x32_shape((num,))
bits1, bits2 = threefry2x32_p.bind(k1, k2, counts1, counts2)
return jnp.stack([bits1, bits2], axis=1)
def threefry_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
def threefry_fold_in(key: jax.Array, data: jax.Array) -> jax.Array:
assert not data.shape
return _threefry_fold_in(key, jnp.uint32(data))
@ -1127,7 +1127,7 @@ def _threefry_fold_in(key, data):
return threefry_2x32(key, threefry_seed(data))
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
def threefry_random_bits(key: jax.Array, bit_width, shape):
"""Sample uniform random bits of given width and shape using PRNG key."""
if not _is_threefry_prng_key(key):
raise TypeError("threefry_random_bits got invalid prng key.")
@ -1140,7 +1140,7 @@ def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
else:
return _threefry_random_bits_original(key, bit_width, shape)
def _threefry_random_bits_partitionable(key: jnp.ndarray, bit_width, shape):
def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape):
if all(core.is_constant_dim(d) for d in shape) and prod(shape) > 2 ** 64:
raise NotImplementedError('random bits array of size exceeding 2 ** 64')
@ -1159,7 +1159,7 @@ def _threefry_random_bits_partitionable(key: jnp.ndarray, bit_width, shape):
return lax.convert_element_type(bits1 ^ bits2, dtype)
@partial(jit, static_argnums=(1, 2), inline=True)
def _threefry_random_bits_original(key: jnp.ndarray, bit_width, shape):
def _threefry_random_bits_original(key: jax.Array, bit_width, shape):
size = prod(shape)
# Compute ceil(bit_width * size / 32) in a way that is friendly to shape
# polymorphism
@ -1219,12 +1219,12 @@ threefry_prng_impl = PRNGImpl(
# stable/deterministic across backends or compiler versions. Correspondingly, we
# reserve the right to change any of these implementations at any time!
def _rbg_seed(seed: jnp.ndarray) -> jnp.ndarray:
def _rbg_seed(seed: jax.Array) -> jax.Array:
assert not seed.shape
halfkey = threefry_seed(seed)
return jnp.concatenate([halfkey, halfkey])
def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
def _rbg_split(key: jax.Array, num: int) -> jax.Array:
if config.jax_threefry_partitionable:
_threefry_split = _threefry_split_foldlike
else:
@ -1232,12 +1232,12 @@ def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
return vmap(
_threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4)
def _rbg_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
def _rbg_fold_in(key: jax.Array, data: jax.Array) -> jax.Array:
assert not data.shape
return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2), data).reshape(4)
def _rbg_random_bits(key: jnp.ndarray, bit_width: int, shape: Sequence[int]
) -> jnp.ndarray:
def _rbg_random_bits(key: jax.Array, bit_width: int, shape: Sequence[int]
) -> jax.Array:
if not key.shape == (4,) and key.dtype == jnp.dtype('uint32'):
raise TypeError("_rbg_random_bits got invalid prng key.")
if bit_width not in (8, 16, 32, 64):
@ -1253,12 +1253,12 @@ rbg_prng_impl = PRNGImpl(
fold_in=_rbg_fold_in,
tag='rbg')
def _unsafe_rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
def _unsafe_rbg_split(key: jax.Array, num: int) -> jax.Array:
# treat 10 iterations of random bits as a 'hash function'
_, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32')
return keys[::10]
def _unsafe_rbg_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
def _unsafe_rbg_fold_in(key: jax.Array, data: jax.Array) -> jax.Array:
assert not data.shape
_, random_bits = lax.rng_bit_generator(_rbg_seed(data), (10, 4), dtype='uint32')
return key ^ random_bits[-1]
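Editor's note: the prng.py hunks annotate raw key data as jax.Array, and the helpers only ever inspect the buffer's .shape and .dtype. A small self-contained check in the same spirit (is_raw_threefry_key is a hypothetical helper mirroring _is_threefry_prng_key above, not part of the diff):

import jax
import jax.numpy as jnp
from jax import random

def is_raw_threefry_key(key_data: jax.Array) -> bool:
  # A raw threefry key is a pair of uint32 words, as checked in
  # _is_threefry_prng_key above.
  return key_data.shape == (2,) and key_data.dtype == jnp.dtype('uint32')

key = random.PRNGKey(0)   # raw uint32[2] key data under the default PRNG config
assert is_raw_threefry_key(key)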

View File

@ -45,19 +45,19 @@ class _BFGSResults(NamedTuple):
line_search_status: int describing line search end state (only means
something if line search fails).
"""
converged: Union[bool, jnp.ndarray]
failed: Union[bool, jnp.ndarray]
k: Union[int, jnp.ndarray]
nfev: Union[int, jnp.ndarray]
ngev: Union[int, jnp.ndarray]
nhev: Union[int, jnp.ndarray]
x_k: jnp.ndarray
f_k: jnp.ndarray
g_k: jnp.ndarray
H_k: jnp.ndarray
old_old_fval: jnp.ndarray
status: Union[int, jnp.ndarray]
line_search_status: Union[int, jnp.ndarray]
converged: Union[bool, jax.Array]
failed: Union[bool, jax.Array]
k: Union[int, jax.Array]
nfev: Union[int, jax.Array]
ngev: Union[int, jax.Array]
nhev: Union[int, jax.Array]
x_k: jax.Array
f_k: jax.Array
g_k: jax.Array
H_k: jax.Array
old_old_fval: jax.Array
status: Union[int, jax.Array]
line_search_status: Union[int, jax.Array]
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
@ -66,7 +66,7 @@ _einsum = partial(jnp.einsum, precision=lax.Precision.HIGHEST)
def minimize_bfgs(
fun: Callable,
x0: jnp.ndarray,
x0: jax.Array,
maxiter: Optional[int] = None,
norm=jnp.inf,
gtol: float = 1e-5,
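Editor's note: the Union[int, jax.Array] / Union[bool, jax.Array] annotations on _BFGSResults (and on the line-search and OptimizeResults tuples further down) reflect that these fields may be plain Python scalars when a result is constructed eagerly, but are traced jax.Array values whenever the state flows through jit or lax.while_loop. A hedged, self-contained sketch of the same idiom with a hypothetical _IterState carrier:

from typing import NamedTuple, Union
import jax
import jax.numpy as jnp

class _IterState(NamedTuple):
  # May hold Python scalars when built by hand, but jax.Array once traced.
  k: Union[int, jax.Array]
  x: jax.Array

def run(x0: jax.Array, steps: int) -> _IterState:
  def cond(state: _IterState):
    return state.k < steps
  def body(state: _IterState) -> _IterState:
    return _IterState(k=state.k + 1, x=0.5 * state.x)
  init = _IterState(k=jnp.asarray(0), x=x0)
  return jax.lax.while_loop(cond, body, init)

state = run(jnp.ones(3), steps=4)   # both fields come back as jax.Array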

View File

@ -59,23 +59,23 @@ def _binary_replace(replace_bit, original_dict, new_dict, keys=None):
class _ZoomState(NamedTuple):
done: Union[bool, jnp.ndarray]
failed: Union[bool, jnp.ndarray]
j: Union[int, jnp.ndarray]
a_lo: Union[float, jnp.ndarray]
phi_lo: Union[float, jnp.ndarray]
dphi_lo: Union[float, jnp.ndarray]
a_hi: Union[float, jnp.ndarray]
phi_hi: Union[float, jnp.ndarray]
dphi_hi: Union[float, jnp.ndarray]
a_rec: Union[float, jnp.ndarray]
phi_rec: Union[float, jnp.ndarray]
a_star: Union[float, jnp.ndarray]
phi_star: Union[float, jnp.ndarray]
dphi_star: Union[float, jnp.ndarray]
g_star: Union[float, jnp.ndarray]
nfev: Union[int, jnp.ndarray]
ngev: Union[int, jnp.ndarray]
done: Union[bool, jax.Array]
failed: Union[bool, jax.Array]
j: Union[int, jax.Array]
a_lo: Union[float, jax.Array]
phi_lo: Union[float, jax.Array]
dphi_lo: Union[float, jax.Array]
a_hi: Union[float, jax.Array]
phi_hi: Union[float, jax.Array]
dphi_hi: Union[float, jax.Array]
a_rec: Union[float, jax.Array]
phi_rec: Union[float, jax.Array]
a_star: Union[float, jax.Array]
phi_star: Union[float, jax.Array]
dphi_star: Union[float, jax.Array]
g_star: Union[float, jax.Array]
nfev: Union[int, jax.Array]
ngev: Union[int, jax.Array]
def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
@ -215,18 +215,18 @@ def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
class _LineSearchState(NamedTuple):
done: Union[bool, jnp.ndarray]
failed: Union[bool, jnp.ndarray]
i: Union[int, jnp.ndarray]
a_i1: Union[float, jnp.ndarray]
phi_i1: Union[float, jnp.ndarray]
dphi_i1: Union[float, jnp.ndarray]
nfev: Union[int, jnp.ndarray]
ngev: Union[int, jnp.ndarray]
a_star: Union[float, jnp.ndarray]
phi_star: Union[float, jnp.ndarray]
dphi_star: Union[float, jnp.ndarray]
g_star: jnp.ndarray
done: Union[bool, jax.Array]
failed: Union[bool, jax.Array]
i: Union[int, jax.Array]
a_i1: Union[float, jax.Array]
phi_i1: Union[float, jax.Array]
dphi_i1: Union[float, jax.Array]
nfev: Union[int, jax.Array]
ngev: Union[int, jax.Array]
a_star: Union[float, jax.Array]
phi_star: Union[float, jax.Array]
dphi_star: Union[float, jax.Array]
g_star: jax.Array
class _LineSearchResults(NamedTuple):
@ -243,15 +243,15 @@ class _LineSearchResults(NamedTuple):
g_k: final gradient value
status: integer end status
"""
failed: Union[bool, jnp.ndarray]
nit: Union[int, jnp.ndarray]
nfev: Union[int, jnp.ndarray]
ngev: Union[int, jnp.ndarray]
k: Union[int, jnp.ndarray]
a_k: Union[int, jnp.ndarray]
f_k: jnp.ndarray
g_k: jnp.ndarray
status: Union[bool, jnp.ndarray]
failed: Union[bool, jax.Array]
nit: Union[int, jax.Array]
nfev: Union[int, jax.Array]
ngev: Union[int, jax.Array]
k: Union[int, jax.Array]
a_k: Union[int, jax.Array]
f_k: jax.Array
g_k: jax.Array
status: Union[bool, jax.Array]
def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Optional, Tuple, Union
import jax
from jax._src.scipy.optimize.bfgs import minimize_bfgs
from jax._src.scipy.optimize._lbfgs import _minimize_lbfgs
from typing import NamedTuple
@ -34,20 +36,20 @@ class OptimizeResults(NamedTuple):
njev: integer number of gradient evaluations.
nit: integer number of iterations of the optimization algorithm.
"""
x: jnp.ndarray
success: Union[bool, jnp.ndarray]
status: Union[int, jnp.ndarray]
fun: jnp.ndarray
jac: jnp.ndarray
hess_inv: Optional[jnp.ndarray]
nfev: Union[int, jnp.ndarray]
njev: Union[int, jnp.ndarray]
nit: Union[int, jnp.ndarray]
x: jax.Array
success: Union[bool, jax.Array]
status: Union[int, jax.Array]
fun: jax.Array
jac: jax.Array
hess_inv: Optional[jax.Array]
nfev: Union[int, jax.Array]
njev: Union[int, jax.Array]
nit: Union[int, jax.Array]
def minimize(
fun: Callable,
x0: jnp.ndarray,
x0: jax.Array,
args: Tuple = (),
*,
method: str,

View File

@ -17,6 +17,8 @@ from functools import partial
import operator
import numpy as np
import jax
import jax.numpy as jnp
from jax import device_put
from jax import lax
@ -87,7 +89,7 @@ def _normalize_matvec(f):
"""Normalize an argument for computing matrix-vector products."""
if callable(f):
return f
elif isinstance(f, (np.ndarray, jnp.ndarray)):
elif isinstance(f, (np.ndarray, jax.Array)):
if f.ndim != 2 or f.shape[0] != f.shape[1]:
raise ValueError(
f'linear operator must be a square matrix, but has shape: {f.shape}')
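Editor's note: _normalize_matvec dispatches at runtime between callables and concrete matrices, and the isinstance check now names jax.Array alongside np.ndarray. A sketch of the same dispatch as a standalone helper (as_matvec is hypothetical, not part of the diff):

from typing import Callable, Union
import numpy as np
import jax
import jax.numpy as jnp

def as_matvec(A: Union[jax.Array, np.ndarray, Callable[[jax.Array], jax.Array]]
              ) -> Callable[[jax.Array], jax.Array]:
  if callable(A):
    return A
  if isinstance(A, (np.ndarray, jax.Array)):
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
      raise ValueError(f'expected a square matrix, got shape {A.shape}')
    return lambda x: jnp.matmul(A, x)
  raise TypeError(f'unsupported linear operator type: {type(A)}')

matvec = as_matvec(jnp.eye(3))
print(matvec(jnp.arange(3.0)))    # [0. 1. 2.]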

View File

@ -16,8 +16,8 @@ from collections import namedtuple
from functools import partial
from typing import Optional, Tuple
import jax
import jax.numpy as jnp
import scipy
from jax import jit
from jax._src import dtypes
from jax._src.api import vmap
@ -26,6 +26,8 @@ from jax._src.numpy.util import _wraps
from jax._src.typing import ArrayLike, Array
from jax._src.util import canonicalize_axis, prod
import scipy
ModeResult = namedtuple('ModeResult', ('mode', 'count'))
@_wraps(scipy.stats.mode, lax_description="""\
@ -68,7 +70,7 @@ def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", k
axis = 0
x = x.ravel()
def _mode_helper(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
def _mode_helper(x: jax.Array) -> Tuple[jax.Array, jax.Array]:
"""Helper function to return mode and count of a given array."""
if x.size == 0:
return jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)), jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_))

View File

@ -19,6 +19,8 @@ https://github.com/google/flax/tree/main/examples/ogbg_molpcba
from typing import Callable, Sequence
from flax import linen as nn
import jax
import jax.numpy as jnp
import jraph
@ -38,7 +40,7 @@ class MLP(nn.Module):
feature_sizes: Sequence[int]
dropout_rate: float = 0
deterministic: bool = True
activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
activation: Callable[[jax.Array], jax.Array] = nn.relu
@nn.compact
def __call__(self, inputs):
@ -131,8 +133,8 @@ class GraphConvNet(nn.Module):
skip_connections: bool = True
layer_norm: bool = True
deterministic: bool = True
pooling_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray],
jnp.ndarray] = jraph.segment_mean
pooling_fn: Callable[[jax.Array, jax.Array, jax.Array],
jax.Array] = jraph.segment_mean
def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
"""Pooling operation, taken from Jraph."""

View File

@ -153,7 +153,7 @@ def _crop_convert_error(msg: str) -> str:
return msg
def _get_random_data(x: jnp.ndarray) -> np.ndarray:
def _get_random_data(x: jax.Array) -> np.ndarray:
dtype = dtypes.canonicalize_dtype(x.dtype)
if np.issubdtype(dtype, np.integer):
return np.random.randint(0, 100, size=x.shape, dtype=dtype)

View File

@ -16,15 +16,15 @@
import abc
from typing import Sequence, Tuple
import jax
from jax._src import core
import jax.numpy as jnp
from jax._src import util
from jax._src.typing import Array
class JAXSparse(abc.ABC):
"""Base class for high-level JAX sparse objects."""
data: jnp.ndarray
data: jax.Array
shape: Tuple[int, ...]
nse: property
dtype: property
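Editor's note: JAXSparse subclasses store their buffers as jax.Array attributes and are registered as pytrees so they can pass through jit, vmap, and grad. A hedged, much-reduced sketch of that shape (DiagMatrix is hypothetical, not a JAX class):

from typing import Tuple
import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class DiagMatrix:
  """Toy sparse-style container holding its nonzeros as a jax.Array."""
  data: jax.Array
  shape: Tuple[int, int]

  def __init__(self, data: jax.Array):
    self.data = data
    self.shape = (data.shape[0], data.shape[0])

  def todense(self) -> jax.Array:
    return jnp.diag(self.data)

  def tree_flatten(self):
    return (self.data,), None   # data is traced; shape is recomputed from it

  @classmethod
  def tree_unflatten(cls, aux, children):
    del aux
    (data,) = children
    return cls(data)

d = DiagMatrix(jnp.arange(3.0))
print(jax.jit(lambda m: m.todense())(d).shape)   # (3, 3)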

View File

@ -23,6 +23,7 @@ from typing import NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
import jax
import jax.numpy as jnp
from jax import config
from jax import lax
@ -97,7 +98,7 @@ def _compatible(shape1: Sequence[int], shape2: Sequence[int]) -> bool:
return all(s1 in (1, s2) for s1, s2 in safe_zip(shape1, shape2))
def _validate_bcsr_indices(indices: jnp.ndarray, indptr: jnp.ndarray,
def _validate_bcsr_indices(indices: jax.Array, indptr: jax.Array,
shape: Sequence[int]) -> BCSRProperties:
assert jnp.issubdtype(indices.dtype, jnp.integer)
assert jnp.issubdtype(indptr.dtype, jnp.integer)
@ -118,8 +119,8 @@ def _validate_bcsr_indices(indices: jnp.ndarray, indptr: jnp.ndarray,
return BCSRProperties(n_batch=n_batch, n_dense=n_dense, nse=nse)
def _validate_bcsr(data: jnp.ndarray, indices: jnp.ndarray,
indptr: jnp.ndarray, shape: Sequence[int]) -> BCSRProperties:
def _validate_bcsr(data: jax.Array, indices: jax.Array,
indptr: jax.Array, shape: Sequence[int]) -> BCSRProperties:
props = _validate_bcsr_indices(indices, indptr, shape)
shape = tuple(shape)
n_batch, n_dense, nse = props.n_batch, props.n_dense, props.nse
@ -134,8 +135,8 @@ def _validate_bcsr(data: jnp.ndarray, indices: jnp.ndarray,
return props
def _bcsr_to_bcoo(indices: jnp.ndarray, indptr: jnp.ndarray, *,
shape: Sequence[int]) -> jnp.ndarray:
def _bcsr_to_bcoo(indices: jax.Array, indptr: jax.Array, *,
shape: Sequence[int]) -> jax.Array:
"""Given BCSR (indices, indptr), return BCOO (indices)."""
n_batch, _, _ = _validate_bcsr_indices(indices, indptr, shape)
csr_to_coo = nfold_vmap(_csr_to_coo, n_batch)
@ -478,8 +479,8 @@ def bcsr_dot_general(lhs: Union[BCSR, Array], rhs: Array, *,
dense, the result will be dense, of type ndarray.
"""
del precision, preferred_element_type # unused
if isinstance(rhs, (np.ndarray, jnp.ndarray)):
if isinstance(lhs, (np.ndarray, jnp.ndarray)):
if isinstance(rhs, (np.ndarray, jax.Array)):
if isinstance(lhs, (np.ndarray, jax.Array)):
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
if isinstance(lhs, BCSR):
@ -492,8 +493,8 @@ def bcsr_dot_general(lhs: Union[BCSR, Array], rhs: Array, *,
"lhs and ndarray rhs.")
def _bcsr_dot_general(lhs_data: jnp.ndarray, lhs_indices: jnp.ndarray,
lhs_indptr: jnp.ndarray, rhs: Array, *,
def _bcsr_dot_general(lhs_data: jax.Array, lhs_indices: jax.Array,
lhs_indptr: jax.Array, rhs: Array, *,
dimension_numbers: DotDimensionNumbers,
lhs_spinfo: SparseInfo) -> Array:
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
@ -708,9 +709,9 @@ def bcsr_broadcast_in_dim(mat: BCSR, *, shape: Shape, broadcast_dimensions: Sequ
class BCSR(JAXSparse):
"""Experimental batched CSR matrix implemented in JAX."""
data: jnp.ndarray
indices: jnp.ndarray
indptr: jnp.ndarray
data: jax.Array
indices: jax.Array
indptr: jax.Array
shape: Shape
nse = property(lambda self: self.indices.shape[-1])
dtype = property(lambda self: self.data.dtype)

View File

@ -22,6 +22,7 @@ import warnings
import numpy as np
import jax
from jax import lax
from jax.interpreters import mlir
from jax.experimental.sparse._base import JAXSparse
@ -54,9 +55,9 @@ class COO(JAXSparse):
grad and autodiff, and offers very little functionality. In general you
should prefer :class:`jax.experimental.sparse.BCOO`.
"""
data: jnp.ndarray
row: jnp.ndarray
col: jnp.ndarray
data: jax.Array
row: jax.Array
col: jax.Array
shape: Tuple[int, int]
nse = property(lambda self: self.data.size)
dtype = property(lambda self: self.data.dtype)

View File

@ -21,6 +21,7 @@ import warnings
import numpy as np
import jax
from jax.interpreters import mlir
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo
@ -43,9 +44,9 @@ class CSR(JAXSparse):
grad and autodiff, and offers very little functionality. In general you
should prefer :class:`jax.experimental.sparse.BCOO`.
"""
data: jnp.ndarray
indices: jnp.ndarray
indptr: jnp.ndarray
data: jax.Array
indices: jax.Array
indptr: jax.Array
shape: Tuple[int, int]
nse = property(lambda self: self.data.size)
dtype = property(lambda self: self.data.dtype)
@ -129,9 +130,9 @@ class CSR(JAXSparse):
@tree_util.register_pytree_node_class
class CSC(JAXSparse):
"""Experimental CSC matrix implemented in JAX; API subject to change."""
data: jnp.ndarray
indices: jnp.ndarray
indptr: jnp.ndarray
data: jax.Array
indices: jax.Array
indptr: jax.Array
shape: Tuple[int, int]
nse = property(lambda self: self.data.size)
dtype = property(lambda self: self.data.dtype)

View File

@ -31,10 +31,10 @@ from scipy.sparse import csr_matrix, linalg
def lobpcg_standard(
A: Union[jnp.ndarray, Callable[[jnp.ndarray], jnp.ndarray]],
X: jnp.ndarray,
A: Union[jax.Array, Callable[[jax.Array], jax.Array]],
X: jax.Array,
m: int = 100,
tol: Union[jnp.ndarray, float, None] = None):
tol: Union[jax.Array, float, None] = None):
"""Compute the top-k standard eigenvalues using the LOBPCG routine.
LOBPCG [1] stands for Locally Optimal Block Preconditioned Conjugate Gradient.
@ -95,16 +95,16 @@ def lobpcg_standard(
large (only `k * 5 < n` supported), or `k == 0`.
"""
# Jit-compile once per matrix shape if possible.
if isinstance(A, (jnp.ndarray, np.ndarray)):
if isinstance(A, (jax.Array, np.ndarray)):
return _lobpcg_standard_matrix(A, X, m, tol, debug=False)
return _lobpcg_standard_callable(A, X, m, tol, debug=False)
@functools.partial(jax.jit, static_argnames=['m', 'debug'])
def _lobpcg_standard_matrix(
A: jnp.ndarray,
X: jnp.ndarray,
A: jax.Array,
X: jax.Array,
m: int,
tol: Union[jnp.ndarray, float, None],
tol: Union[jax.Array, float, None],
debug: bool = False):
"""Computes lobpcg_standard(), possibly with debug diagnostics."""
return _lobpcg_standard_callable(
@ -112,10 +112,10 @@ def _lobpcg_standard_matrix(
@functools.partial(jax.jit, static_argnames=['A', 'm', 'debug'])
def _lobpcg_standard_callable(
A: Callable[[jnp.ndarray], jnp.ndarray],
X: jnp.ndarray,
A: Callable[[jax.Array], jax.Array],
X: jax.Array,
m: int,
tol: Union[jnp.ndarray, float, None],
tol: Union[jax.Array, float, None],
debug: bool = False):
"""Supports generic lobpcg_standard() callable interface."""

View File

@ -1474,11 +1474,11 @@ class APITest(jtu.JaxTestCase):
"Transpose rule (for reverse-mode differentiation) for 'foo' not implemented")
def test_is_subclass(self):
self.assertTrue(issubclass(device_array.DeviceArray, jnp.ndarray))
self.assertTrue(issubclass(device_array.Buffer, jnp.ndarray))
self.assertTrue(issubclass(pxla.ShardedDeviceArray, jnp.ndarray))
self.assertTrue(issubclass(pxla._ShardedDeviceArray, jnp.ndarray))
self.assertFalse(issubclass(np.ndarray, jnp.ndarray))
self.assertTrue(issubclass(device_array.DeviceArray, jax.Array))
self.assertTrue(issubclass(device_array.Buffer, jax.Array))
self.assertTrue(issubclass(pxla.ShardedDeviceArray, jax.Array))
self.assertTrue(issubclass(pxla._ShardedDeviceArray, jax.Array))
self.assertFalse(issubclass(np.ndarray, jax.Array))
self.assertFalse(issubclass(device_array.DeviceArray, np.ndarray))
self.assertFalse(issubclass(device_array.Buffer, np.ndarray))
self.assertFalse(issubclass(pxla.ShardedDeviceArray, np.ndarray))
@ -1486,7 +1486,7 @@ class APITest(jtu.JaxTestCase):
def test_is_instance(self):
def f(x):
self.assertIsInstance(x, jnp.ndarray)
self.assertIsInstance(x, jax.Array)
self.assertNotIsInstance(x, np.ndarray)
return x + 2
jit(f)(3)
@ -1496,10 +1496,10 @@ class APITest(jtu.JaxTestCase):
x = np.arange(12.).reshape((3, 4)).astype("float32")
dx = api.device_put(x)
_check_instance(self, dx)
self.assertIsInstance(dx, jnp.ndarray)
self.assertIsInstance(dx, jax.Array)
self.assertNotIsInstance(dx, np.ndarray)
x2 = api.device_get(dx)
self.assertNotIsInstance(x2, jnp.ndarray)
self.assertNotIsInstance(x2, jax.Array)
self.assertIsInstance(x2, np.ndarray)
assert np.all(x == x2)
@ -7114,7 +7114,7 @@ class CustomJVPTest(jtu.JaxTestCase):
raise unittest.SkipTest("test only applies when x64 is disabled")
@jax.custom_jvp
def projection_unit_simplex(x: jnp.ndarray) -> jnp.ndarray:
def projection_unit_simplex(x: jax.Array) -> jax.Array:
"""Projection onto the unit simplex."""
s = 1.0
n_features = x.shape[0]
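Editor's note: the updated tests treat jax.Array as the public runtime type: jit inputs (tracers), jit and device_put outputs, and values produced by jnp all satisfy isinstance(x, jax.Array), while device_get returns a plain np.ndarray. A compact standalone version of those assertions (not taken from the test file):

import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def double(x):
  assert isinstance(x, jax.Array)        # traced values count as jax.Array too
  return 2 * x

y = double(jnp.arange(3))
assert isinstance(y, jax.Array)
assert not isinstance(y, np.ndarray)     # committed jax.Array is not an np.ndarray
assert isinstance(jax.device_get(y), np.ndarray)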

View File

@ -762,7 +762,7 @@ class ShardingTest(jtu.JaxTestCase):
def test_is_subclass(self):
# array version of api_test.py::APITest::test_is_subclass
self.assertTrue(issubclass(array.ArrayImpl, jnp.ndarray))
self.assertTrue(issubclass(array.ArrayImpl, jax.Array))
self.assertFalse(issubclass(array.ArrayImpl, np.ndarray))
def test_op_sharding_sharding_repr(self):

View File

@ -98,7 +98,7 @@ class DtypesTest(jtu.JaxTestCase):
expected_dtype = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[type_])
for f in [jnp.array, jax.jit(jnp.array), jax.jit(lambda x: x)]:
y = f(type_(0))
self.assertTrue(isinstance(y, jnp.ndarray), msg=(f, y))
self.assertTrue(isinstance(y, jax.Array), msg=(f, y))
self.assertEqual(y.dtype, expected_dtype, msg=(f, y))
def testUnsupportedType(self):
@ -141,7 +141,7 @@ class DtypesTest(jtu.JaxTestCase):
for x, y, dtype in testcases:
x, y = (y, x) if swap else (x, y)
z = op(x, y)
self.assertTrue(isinstance(z, jnp.ndarray), msg=(x, y, z))
self.assertTrue(isinstance(z, jax.Array), msg=(x, y, z))
self.assertEqual(z.dtype, dtypes.canonicalize_dtype(dtype), msg=(x, y, z))
@jax.numpy_dtype_promotion('strict')

View File

@ -160,7 +160,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if dtype == dtypes.canonicalize_dtype(dtype)])
def testDtypeWrappers(self, dtype):
arr = dtype(0)
self.assertIsInstance(arr, jnp.ndarray)
self.assertIsInstance(arr, jax.Array)
self.assertEqual(arr.dtype, np.dtype(dtype))
self.assertArraysEqual(arr, 0, check_dtypes=False)
@ -3183,7 +3183,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
@jax.jit
def f(x):
self.assertIsInstance(x, jnp.ndarray)
self.assertIsInstance(x, jax.Array)
return jnp.sum(x)
f(arr)
@ -4833,7 +4833,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
check_dtypes=False)
def testBroadcastToOnScalar(self):
self.assertIsInstance(jnp.broadcast_to(10.0, ()), jnp.ndarray)
self.assertIsInstance(jnp.broadcast_to(10.0, ()), jax.Array)
self.assertIsInstance(np.broadcast_to(10.0, ()), np.ndarray)
def testPrecision(self):

View File

@ -2951,7 +2951,7 @@ def take_abstract_eval(x):
class FooArray:
shape: Tuple[int, ...]
data: jnp.ndarray
data: jax.Array
def __init__(self, shape, data):
assert data.shape == (*shape, 2)

View File

@ -274,7 +274,7 @@ class LobpcgTest(jtu.JaxTestCase):
if not os.getenv('LOBPCG_EMIT_DEBUG_PLOTS'):
return
if isinstance(A, (np.ndarray, jnp.ndarray)):
if isinstance(A, (np.ndarray, jax.Array)):
lobpcg = linalg._lobpcg_standard_matrix
else:
lobpcg = linalg._lobpcg_standard_callable

View File

@ -775,7 +775,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# test that we can pass in and out ShardedDeviceArrays
y = f(x)
self.assertIsInstance(y, jnp.ndarray)
self.assertIsInstance(y, jax.Array)
if config.jax_array:
self.assertIsInstance(y, array.ArrayImpl)
else:
@ -1655,7 +1655,7 @@ class PythonPmapTest(jtu.JaxTestCase):
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = f(x)
self.assertIsInstance(y, jnp.ndarray)
self.assertIsInstance(y, jax.Array)
if config.jax_array:
self.assertIsInstance(y, array.ArrayImpl)
else:

View File

@ -39,7 +39,7 @@ _QDWH_TEST_EPS = jnp.finfo(_QDWH_TEST_DTYPE).eps
_MAX_LOG_CONDITION_NUM = np.log10(int(1 / _QDWH_TEST_EPS))
def _check_symmetry(x: jnp.ndarray) -> bool:
def _check_symmetry(x: jax.Array) -> bool:
"""Check if the array is symmetric."""
m, n = x.shape
eps = jnp.finfo(x.dtype).eps

View File

@ -506,7 +506,7 @@ class PrngTest(jtu.JaxTestCase):
if not config.jax_enable_custom_prng:
self.skipTest("test requires config.jax_enable_custom_prng")
key = random.PRNGKey(0)
self.assertIsInstance(key, jnp.ndarray)
self.assertIsInstance(key, jax.Array)
class LaxRandomTest(jtu.JaxTestCase):

View File

@ -2381,7 +2381,7 @@ class SparseObjectTest(sptu.SparseTestCase):
with self.subTest('to_elt'):
M_out = vmap(to_elt)(Msp)
self.assertIsInstance(M_out, jnp.ndarray)
self.assertIsInstance(M_out, jax.Array)
self.assertEqual(Msp.shape, M_out.shape)
with self.subTest('axis_None'):