mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Replace uses of jnp.ndarray with jax.Array inside JAX.
PiperOrigin-RevId: 509939691
This commit is contained in:
parent
7aa7e158f8
commit
cd0533cab0
@ -43,6 +43,7 @@ del _core
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
from jax._src.basearray import Array as Array
|
||||
from jax import typing as typing
|
||||
|
||||
from jax._src.config import (
|
||||
config as config,
|
||||
|
@ -1137,7 +1137,7 @@ def _check_error(error, *, debug=False):
|
||||
|
||||
def is_scalar_pred(pred) -> bool:
|
||||
return (isinstance(pred, bool) or
|
||||
isinstance(pred, jnp.ndarray) and pred.shape == () and
|
||||
isinstance(pred, jax.Array) and pred.shape == () and
|
||||
pred.dtype == jnp.dtype('bool'))
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ from typing import Any, Dict, Hashable, List, Optional, Protocol, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax
|
||||
from jax import tree_util
|
||||
from jax._src import core
|
||||
from jax._src import debugging
|
||||
@ -84,7 +84,7 @@ class DebuggerFrame:
|
||||
flat_globals, globals_tree = _safe_flatten_dict(self.globals)
|
||||
flat_vars = flat_locals + flat_globals
|
||||
is_valid = [
|
||||
isinstance(l, (core.Tracer, jnp.ndarray, np.ndarray))
|
||||
isinstance(l, (core.Tracer, jax.Array, np.ndarray))
|
||||
for l in flat_vars
|
||||
]
|
||||
invalid_vars, valid_vars = util.partition_list(is_valid, flat_vars)
|
||||
|
@ -257,10 +257,10 @@ class _Subproblem(NamedTuple):
|
||||
in the workspace.
|
||||
"""
|
||||
# The row offset of the block in the matrix of blocks.
|
||||
offset: jnp.ndarray
|
||||
offset: jax.Array
|
||||
|
||||
# The size of the block.
|
||||
size: jnp.ndarray
|
||||
size: jax.Array
|
||||
|
||||
@partial(jax.jit, static_argnames=('termination_size',))
|
||||
def _eigh_work(H, n, termination_size=256):
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
|
||||
from typing import Any, Optional, Sequence, Tuple, Union, cast as type_cast
|
||||
import jax
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.util import prod
|
||||
from jax._src.lax import lax
|
||||
@ -22,7 +23,7 @@ from jax._src.lax import convolution
|
||||
DType = Any
|
||||
|
||||
def conv_general_dilated_patches(
|
||||
lhs: lax.Array,
|
||||
lhs: jax.typing.ArrayLike,
|
||||
filter_shape: Sequence[int],
|
||||
window_strides: Sequence[int],
|
||||
padding: Union[str, Sequence[Tuple[int, int]]],
|
||||
@ -31,7 +32,7 @@ def conv_general_dilated_patches(
|
||||
dimension_numbers: Optional[convolution.ConvGeneralDilatedDimensionNumbers] = None,
|
||||
precision: Optional[lax.PrecisionType] = None,
|
||||
preferred_element_type: Optional[DType] = None,
|
||||
) -> lax.Array:
|
||||
) -> jax.Array:
|
||||
"""Extract patches subject to the receptive field of `conv_general_dilated`.
|
||||
|
||||
Runs the input through a convolution with given parameters. The kernel of the
|
||||
@ -84,24 +85,25 @@ def conv_general_dilated_patches(
|
||||
(`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).
|
||||
|
||||
"""
|
||||
lhs_array = jnp.asarray(lhs)
|
||||
filter_shape = tuple(filter_shape)
|
||||
dimension_numbers = convolution.conv_dimension_numbers(
|
||||
lhs.shape, (1, 1) + filter_shape, dimension_numbers)
|
||||
lhs_array.shape, (1, 1) + filter_shape, dimension_numbers)
|
||||
|
||||
lhs_spec, rhs_spec, out_spec = dimension_numbers
|
||||
|
||||
spatial_size = prod(filter_shape)
|
||||
n_channels = lhs.shape[lhs_spec[1]]
|
||||
n_channels = lhs_array.shape[lhs_spec[1]]
|
||||
|
||||
# Move separate `lhs` spatial locations into separate `rhs` channels.
|
||||
rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)
|
||||
rhs = jnp.eye(spatial_size, dtype=lhs_array.dtype).reshape(filter_shape * 2)
|
||||
|
||||
rhs = rhs.reshape((spatial_size, 1) + filter_shape)
|
||||
rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1))
|
||||
rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))
|
||||
|
||||
out = convolution.conv_general_dilated(
|
||||
lhs=lhs,
|
||||
lhs=lhs_array,
|
||||
rhs=rhs,
|
||||
window_strides=window_strides,
|
||||
padding=padding,
|
||||
@ -117,8 +119,8 @@ def conv_general_dilated_patches(
|
||||
|
||||
|
||||
def conv_general_dilated_local(
|
||||
lhs: jnp.ndarray,
|
||||
rhs: jnp.ndarray,
|
||||
lhs: jax.typing.ArrayLike,
|
||||
rhs: jax.typing.ArrayLike,
|
||||
window_strides: Sequence[int],
|
||||
padding: Union[str, Sequence[Tuple[int, int]]],
|
||||
filter_shape: Sequence[int],
|
||||
@ -126,7 +128,7 @@ def conv_general_dilated_local(
|
||||
rhs_dilation: Optional[Sequence[int]] = None,
|
||||
dimension_numbers: Optional[convolution.ConvGeneralDilatedDimensionNumbers] = None,
|
||||
precision: lax.PrecisionLike = None
|
||||
) -> jnp.ndarray:
|
||||
) -> jax.Array:
|
||||
"""General n-dimensional unshared convolution operator with optional dilation.
|
||||
|
||||
Also known as locally connected layer, the operation is equivalent to
|
||||
@ -195,6 +197,8 @@ def conv_general_dilated_local(
|
||||
If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
|
||||
(for a 2D convolution).
|
||||
"""
|
||||
lhs_array = jnp.asarray(lhs)
|
||||
|
||||
c_precision = lax.canonicalize_precision(precision)
|
||||
lhs_precision = type_cast(
|
||||
Optional[lax.PrecisionType],
|
||||
@ -203,7 +207,7 @@ def conv_general_dilated_local(
|
||||
else c_precision))
|
||||
|
||||
patches = conv_general_dilated_patches(
|
||||
lhs=lhs,
|
||||
lhs=lhs_array,
|
||||
filter_shape=filter_shape,
|
||||
window_strides=window_strides,
|
||||
padding=padding,
|
||||
@ -214,7 +218,7 @@ def conv_general_dilated_local(
|
||||
)
|
||||
|
||||
lhs_spec, rhs_spec, out_spec = convolution.conv_dimension_numbers(
|
||||
lhs.shape, (1, 1) + tuple(filter_shape), dimension_numbers)
|
||||
lhs_array.shape, (1, 1) + tuple(filter_shape), dimension_numbers)
|
||||
|
||||
lhs_c_dims, rhs_c_dims = [out_spec[1]], [rhs_spec[1]]
|
||||
|
||||
|
@ -94,7 +94,7 @@ class PRNGImpl(NamedTuple):
|
||||
|
||||
# -- PRNG key arrays
|
||||
|
||||
def _check_prng_key_data(impl, key_data: jnp.ndarray):
|
||||
def _check_prng_key_data(impl, key_data: jax.Array):
|
||||
ndim = len(impl.key_shape)
|
||||
if not all(hasattr(key_data, attr) for attr in ['ndim', 'shape', 'dtype']):
|
||||
raise TypeError("JAX encountered invalid PRNG key data: expected key_data "
|
||||
@ -136,7 +136,7 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
|
||||
"""
|
||||
|
||||
impl: PRNGImpl
|
||||
_base_array: jnp.ndarray
|
||||
_base_array: jax.Array
|
||||
|
||||
def __init__(self, impl, key_data: Any):
|
||||
assert not isinstance(key_data, core.Tracer)
|
||||
@ -794,14 +794,14 @@ mlir.register_lowering(random_unwrap_p, random_unwrap_lowering)
|
||||
# -- threefry2x32 PRNG implementation
|
||||
|
||||
|
||||
def _is_threefry_prng_key(key: jnp.ndarray) -> bool:
|
||||
def _is_threefry_prng_key(key: jax.Array) -> bool:
|
||||
try:
|
||||
return key.shape == (2,) and key.dtype == np.uint32
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
def threefry_seed(seed: jnp.ndarray) -> jnp.ndarray:
|
||||
def threefry_seed(seed: jax.Array) -> jax.Array:
|
||||
"""Create a single raw threefry PRNG key from an integer seed.
|
||||
|
||||
Args:
|
||||
@ -1099,26 +1099,26 @@ def threefry_2x32(keypair, count):
|
||||
return lax.reshape(out[:-1] if odd_size else out, count.shape)
|
||||
|
||||
|
||||
def threefry_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
|
||||
def threefry_split(key: jax.Array, num: int) -> jax.Array:
|
||||
if config.jax_threefry_partitionable:
|
||||
return _threefry_split_foldlike(key, int(num)) # type: ignore
|
||||
else:
|
||||
return _threefry_split_original(key, int(num)) # type: ignore
|
||||
|
||||
@partial(jit, static_argnums=(1,), inline=True)
|
||||
def _threefry_split_original(key, num) -> jnp.ndarray:
|
||||
def _threefry_split_original(key, num) -> jax.Array:
|
||||
counts = lax.iota(np.uint32, num * 2)
|
||||
return lax.reshape(threefry_2x32(key, counts), (num, 2))
|
||||
|
||||
@partial(jit, static_argnums=(1,), inline=True)
|
||||
def _threefry_split_foldlike(key, num) -> jnp.ndarray:
|
||||
def _threefry_split_foldlike(key, num) -> jax.Array:
|
||||
k1, k2 = key
|
||||
counts1, counts2 = iota_2x32_shape((num,))
|
||||
bits1, bits2 = threefry2x32_p.bind(k1, k2, counts1, counts2)
|
||||
return jnp.stack([bits1, bits2], axis=1)
|
||||
|
||||
|
||||
def threefry_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
|
||||
def threefry_fold_in(key: jax.Array, data: jax.Array) -> jax.Array:
|
||||
assert not data.shape
|
||||
return _threefry_fold_in(key, jnp.uint32(data))
|
||||
|
||||
@ -1127,7 +1127,7 @@ def _threefry_fold_in(key, data):
|
||||
return threefry_2x32(key, threefry_seed(data))
|
||||
|
||||
|
||||
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
|
||||
def threefry_random_bits(key: jax.Array, bit_width, shape):
|
||||
"""Sample uniform random bits of given width and shape using PRNG key."""
|
||||
if not _is_threefry_prng_key(key):
|
||||
raise TypeError("threefry_random_bits got invalid prng key.")
|
||||
@ -1140,7 +1140,7 @@ def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
|
||||
else:
|
||||
return _threefry_random_bits_original(key, bit_width, shape)
|
||||
|
||||
def _threefry_random_bits_partitionable(key: jnp.ndarray, bit_width, shape):
|
||||
def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape):
|
||||
if all(core.is_constant_dim(d) for d in shape) and prod(shape) > 2 ** 64:
|
||||
raise NotImplementedError('random bits array of size exceeding 2 ** 64')
|
||||
|
||||
@ -1159,7 +1159,7 @@ def _threefry_random_bits_partitionable(key: jnp.ndarray, bit_width, shape):
|
||||
return lax.convert_element_type(bits1 ^ bits2, dtype)
|
||||
|
||||
@partial(jit, static_argnums=(1, 2), inline=True)
|
||||
def _threefry_random_bits_original(key: jnp.ndarray, bit_width, shape):
|
||||
def _threefry_random_bits_original(key: jax.Array, bit_width, shape):
|
||||
size = prod(shape)
|
||||
# Compute ceil(bit_width * size / 32) in a way that is friendly to shape
|
||||
# polymorphism
|
||||
@ -1219,12 +1219,12 @@ threefry_prng_impl = PRNGImpl(
|
||||
# stable/deterministic across backends or compiler versions. Correspondingly, we
|
||||
# reserve the right to change any of these implementations at any time!
|
||||
|
||||
def _rbg_seed(seed: jnp.ndarray) -> jnp.ndarray:
|
||||
def _rbg_seed(seed: jax.Array) -> jax.Array:
|
||||
assert not seed.shape
|
||||
halfkey = threefry_seed(seed)
|
||||
return jnp.concatenate([halfkey, halfkey])
|
||||
|
||||
def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
|
||||
def _rbg_split(key: jax.Array, num: int) -> jax.Array:
|
||||
if config.jax_threefry_partitionable:
|
||||
_threefry_split = _threefry_split_foldlike
|
||||
else:
|
||||
@ -1232,12 +1232,12 @@ def _rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
|
||||
return vmap(
|
||||
_threefry_split, (0, None), 1)(key.reshape(2, 2), num).reshape(num, 4)
|
||||
|
||||
def _rbg_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
|
||||
def _rbg_fold_in(key: jax.Array, data: jax.Array) -> jax.Array:
|
||||
assert not data.shape
|
||||
return vmap(_threefry_fold_in, (0, None), 0)(key.reshape(2, 2), data).reshape(4)
|
||||
|
||||
def _rbg_random_bits(key: jnp.ndarray, bit_width: int, shape: Sequence[int]
|
||||
) -> jnp.ndarray:
|
||||
def _rbg_random_bits(key: jax.Array, bit_width: int, shape: Sequence[int]
|
||||
) -> jax.Array:
|
||||
if not key.shape == (4,) and key.dtype == jnp.dtype('uint32'):
|
||||
raise TypeError("_rbg_random_bits got invalid prng key.")
|
||||
if bit_width not in (8, 16, 32, 64):
|
||||
@ -1253,12 +1253,12 @@ rbg_prng_impl = PRNGImpl(
|
||||
fold_in=_rbg_fold_in,
|
||||
tag='rbg')
|
||||
|
||||
def _unsafe_rbg_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
|
||||
def _unsafe_rbg_split(key: jax.Array, num: int) -> jax.Array:
|
||||
# treat 10 iterations of random bits as a 'hash function'
|
||||
_, keys = lax.rng_bit_generator(key, (10 * num, 4), dtype='uint32')
|
||||
return keys[::10]
|
||||
|
||||
def _unsafe_rbg_fold_in(key: jnp.ndarray, data: jnp.ndarray) -> jnp.ndarray:
|
||||
def _unsafe_rbg_fold_in(key: jax.Array, data: jax.Array) -> jax.Array:
|
||||
assert not data.shape
|
||||
_, random_bits = lax.rng_bit_generator(_rbg_seed(data), (10, 4), dtype='uint32')
|
||||
return key ^ random_bits[-1]
|
||||
|
@ -45,19 +45,19 @@ class _BFGSResults(NamedTuple):
|
||||
line_search_status: int describing line search end state (only means
|
||||
something if line search fails).
|
||||
"""
|
||||
converged: Union[bool, jnp.ndarray]
|
||||
failed: Union[bool, jnp.ndarray]
|
||||
k: Union[int, jnp.ndarray]
|
||||
nfev: Union[int, jnp.ndarray]
|
||||
ngev: Union[int, jnp.ndarray]
|
||||
nhev: Union[int, jnp.ndarray]
|
||||
x_k: jnp.ndarray
|
||||
f_k: jnp.ndarray
|
||||
g_k: jnp.ndarray
|
||||
H_k: jnp.ndarray
|
||||
old_old_fval: jnp.ndarray
|
||||
status: Union[int, jnp.ndarray]
|
||||
line_search_status: Union[int, jnp.ndarray]
|
||||
converged: Union[bool, jax.Array]
|
||||
failed: Union[bool, jax.Array]
|
||||
k: Union[int, jax.Array]
|
||||
nfev: Union[int, jax.Array]
|
||||
ngev: Union[int, jax.Array]
|
||||
nhev: Union[int, jax.Array]
|
||||
x_k: jax.Array
|
||||
f_k: jax.Array
|
||||
g_k: jax.Array
|
||||
H_k: jax.Array
|
||||
old_old_fval: jax.Array
|
||||
status: Union[int, jax.Array]
|
||||
line_search_status: Union[int, jax.Array]
|
||||
|
||||
|
||||
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
|
||||
@ -66,7 +66,7 @@ _einsum = partial(jnp.einsum, precision=lax.Precision.HIGHEST)
|
||||
|
||||
def minimize_bfgs(
|
||||
fun: Callable,
|
||||
x0: jnp.ndarray,
|
||||
x0: jax.Array,
|
||||
maxiter: Optional[int] = None,
|
||||
norm=jnp.inf,
|
||||
gtol: float = 1e-5,
|
||||
|
@ -59,23 +59,23 @@ def _binary_replace(replace_bit, original_dict, new_dict, keys=None):
|
||||
|
||||
|
||||
class _ZoomState(NamedTuple):
|
||||
done: Union[bool, jnp.ndarray]
|
||||
failed: Union[bool, jnp.ndarray]
|
||||
j: Union[int, jnp.ndarray]
|
||||
a_lo: Union[float, jnp.ndarray]
|
||||
phi_lo: Union[float, jnp.ndarray]
|
||||
dphi_lo: Union[float, jnp.ndarray]
|
||||
a_hi: Union[float, jnp.ndarray]
|
||||
phi_hi: Union[float, jnp.ndarray]
|
||||
dphi_hi: Union[float, jnp.ndarray]
|
||||
a_rec: Union[float, jnp.ndarray]
|
||||
phi_rec: Union[float, jnp.ndarray]
|
||||
a_star: Union[float, jnp.ndarray]
|
||||
phi_star: Union[float, jnp.ndarray]
|
||||
dphi_star: Union[float, jnp.ndarray]
|
||||
g_star: Union[float, jnp.ndarray]
|
||||
nfev: Union[int, jnp.ndarray]
|
||||
ngev: Union[int, jnp.ndarray]
|
||||
done: Union[bool, jax.Array]
|
||||
failed: Union[bool, jax.Array]
|
||||
j: Union[int, jax.Array]
|
||||
a_lo: Union[float, jax.Array]
|
||||
phi_lo: Union[float, jax.Array]
|
||||
dphi_lo: Union[float, jax.Array]
|
||||
a_hi: Union[float, jax.Array]
|
||||
phi_hi: Union[float, jax.Array]
|
||||
dphi_hi: Union[float, jax.Array]
|
||||
a_rec: Union[float, jax.Array]
|
||||
phi_rec: Union[float, jax.Array]
|
||||
a_star: Union[float, jax.Array]
|
||||
phi_star: Union[float, jax.Array]
|
||||
dphi_star: Union[float, jax.Array]
|
||||
g_star: Union[float, jax.Array]
|
||||
nfev: Union[int, jax.Array]
|
||||
ngev: Union[int, jax.Array]
|
||||
|
||||
|
||||
def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
|
||||
@ -215,18 +215,18 @@ def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo,
|
||||
|
||||
|
||||
class _LineSearchState(NamedTuple):
|
||||
done: Union[bool, jnp.ndarray]
|
||||
failed: Union[bool, jnp.ndarray]
|
||||
i: Union[int, jnp.ndarray]
|
||||
a_i1: Union[float, jnp.ndarray]
|
||||
phi_i1: Union[float, jnp.ndarray]
|
||||
dphi_i1: Union[float, jnp.ndarray]
|
||||
nfev: Union[int, jnp.ndarray]
|
||||
ngev: Union[int, jnp.ndarray]
|
||||
a_star: Union[float, jnp.ndarray]
|
||||
phi_star: Union[float, jnp.ndarray]
|
||||
dphi_star: Union[float, jnp.ndarray]
|
||||
g_star: jnp.ndarray
|
||||
done: Union[bool, jax.Array]
|
||||
failed: Union[bool, jax.Array]
|
||||
i: Union[int, jax.Array]
|
||||
a_i1: Union[float, jax.Array]
|
||||
phi_i1: Union[float, jax.Array]
|
||||
dphi_i1: Union[float, jax.Array]
|
||||
nfev: Union[int, jax.Array]
|
||||
ngev: Union[int, jax.Array]
|
||||
a_star: Union[float, jax.Array]
|
||||
phi_star: Union[float, jax.Array]
|
||||
dphi_star: Union[float, jax.Array]
|
||||
g_star: jax.Array
|
||||
|
||||
|
||||
class _LineSearchResults(NamedTuple):
|
||||
@ -243,15 +243,15 @@ class _LineSearchResults(NamedTuple):
|
||||
g_k: final gradient value
|
||||
status: integer end status
|
||||
"""
|
||||
failed: Union[bool, jnp.ndarray]
|
||||
nit: Union[int, jnp.ndarray]
|
||||
nfev: Union[int, jnp.ndarray]
|
||||
ngev: Union[int, jnp.ndarray]
|
||||
k: Union[int, jnp.ndarray]
|
||||
a_k: Union[int, jnp.ndarray]
|
||||
f_k: jnp.ndarray
|
||||
g_k: jnp.ndarray
|
||||
status: Union[bool, jnp.ndarray]
|
||||
failed: Union[bool, jax.Array]
|
||||
nit: Union[int, jax.Array]
|
||||
nfev: Union[int, jax.Array]
|
||||
ngev: Union[int, jax.Array]
|
||||
k: Union[int, jax.Array]
|
||||
a_k: Union[int, jax.Array]
|
||||
f_k: jax.Array
|
||||
g_k: jax.Array
|
||||
status: Union[bool, jax.Array]
|
||||
|
||||
|
||||
def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
|
||||
|
@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Callable, Mapping, Optional, Tuple, Union
|
||||
|
||||
import jax
|
||||
from jax._src.scipy.optimize.bfgs import minimize_bfgs
|
||||
from jax._src.scipy.optimize._lbfgs import _minimize_lbfgs
|
||||
from typing import NamedTuple
|
||||
@ -34,20 +36,20 @@ class OptimizeResults(NamedTuple):
|
||||
njev: integer number of gradient evaluations.
|
||||
nit: integer number of iterations of the optimization algorithm.
|
||||
"""
|
||||
x: jnp.ndarray
|
||||
success: Union[bool, jnp.ndarray]
|
||||
status: Union[int, jnp.ndarray]
|
||||
fun: jnp.ndarray
|
||||
jac: jnp.ndarray
|
||||
hess_inv: Optional[jnp.ndarray]
|
||||
nfev: Union[int, jnp.ndarray]
|
||||
njev: Union[int, jnp.ndarray]
|
||||
nit: Union[int, jnp.ndarray]
|
||||
x: jax.Array
|
||||
success: Union[bool, jax.Array]
|
||||
status: Union[int, jax.Array]
|
||||
fun: jax.Array
|
||||
jac: jax.Array
|
||||
hess_inv: Optional[jax.Array]
|
||||
nfev: Union[int, jax.Array]
|
||||
njev: Union[int, jax.Array]
|
||||
nit: Union[int, jax.Array]
|
||||
|
||||
|
||||
def minimize(
|
||||
fun: Callable,
|
||||
x0: jnp.ndarray,
|
||||
x0: jax.Array,
|
||||
args: Tuple = (),
|
||||
*,
|
||||
method: str,
|
||||
|
@ -17,6 +17,8 @@ from functools import partial
|
||||
import operator
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import device_put
|
||||
from jax import lax
|
||||
@ -87,7 +89,7 @@ def _normalize_matvec(f):
|
||||
"""Normalize an argument for computing matrix-vector products."""
|
||||
if callable(f):
|
||||
return f
|
||||
elif isinstance(f, (np.ndarray, jnp.ndarray)):
|
||||
elif isinstance(f, (np.ndarray, jax.Array)):
|
||||
if f.ndim != 2 or f.shape[0] != f.shape[1]:
|
||||
raise ValueError(
|
||||
f'linear operator must be a square matrix, but has shape: {f.shape}')
|
||||
|
@ -16,8 +16,8 @@ from collections import namedtuple
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import scipy
|
||||
from jax import jit
|
||||
from jax._src import dtypes
|
||||
from jax._src.api import vmap
|
||||
@ -26,6 +26,8 @@ from jax._src.numpy.util import _wraps
|
||||
from jax._src.typing import ArrayLike, Array
|
||||
from jax._src.util import canonicalize_axis, prod
|
||||
|
||||
import scipy
|
||||
|
||||
ModeResult = namedtuple('ModeResult', ('mode', 'count'))
|
||||
|
||||
@_wraps(scipy.stats.mode, lax_description="""\
|
||||
@ -68,7 +70,7 @@ def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", k
|
||||
axis = 0
|
||||
x = x.ravel()
|
||||
|
||||
def _mode_helper(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||
def _mode_helper(x: jax.Array) -> Tuple[jax.Array, jax.Array]:
|
||||
"""Helper function to return mode and count of a given array."""
|
||||
if x.size == 0:
|
||||
return jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)), jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_))
|
||||
|
@ -19,6 +19,8 @@ https://github.com/google/flax/tree/main/examples/ogbg_molpcba
|
||||
from typing import Callable, Sequence
|
||||
|
||||
from flax import linen as nn
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jraph
|
||||
|
||||
@ -38,7 +40,7 @@ class MLP(nn.Module):
|
||||
feature_sizes: Sequence[int]
|
||||
dropout_rate: float = 0
|
||||
deterministic: bool = True
|
||||
activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
|
||||
activation: Callable[[jax.Array], jax.Array] = nn.relu
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs):
|
||||
@ -131,8 +133,8 @@ class GraphConvNet(nn.Module):
|
||||
skip_connections: bool = True
|
||||
layer_norm: bool = True
|
||||
deterministic: bool = True
|
||||
pooling_fn: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray],
|
||||
jnp.ndarray] = jraph.segment_mean
|
||||
pooling_fn: Callable[[jax.Array, jax.Array, jax.Array],
|
||||
jax.Array] = jraph.segment_mean
|
||||
|
||||
def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
|
||||
"""Pooling operation, taken from Jraph."""
|
||||
|
@ -153,7 +153,7 @@ def _crop_convert_error(msg: str) -> str:
|
||||
return msg
|
||||
|
||||
|
||||
def _get_random_data(x: jnp.ndarray) -> np.ndarray:
|
||||
def _get_random_data(x: jax.Array) -> np.ndarray:
|
||||
dtype = dtypes.canonicalize_dtype(x.dtype)
|
||||
if np.issubdtype(dtype, np.integer):
|
||||
return np.random.randint(0, 100, size=x.shape, dtype=dtype)
|
||||
|
@ -16,15 +16,15 @@
|
||||
import abc
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
import jax.numpy as jnp
|
||||
from jax._src import util
|
||||
from jax._src.typing import Array
|
||||
|
||||
|
||||
class JAXSparse(abc.ABC):
|
||||
"""Base class for high-level JAX sparse objects."""
|
||||
data: jnp.ndarray
|
||||
data: jax.Array
|
||||
shape: Tuple[int, ...]
|
||||
nse: property
|
||||
dtype: property
|
||||
|
@ -23,6 +23,7 @@ from typing import NamedTuple, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import config
|
||||
from jax import lax
|
||||
@ -97,7 +98,7 @@ def _compatible(shape1: Sequence[int], shape2: Sequence[int]) -> bool:
|
||||
return all(s1 in (1, s2) for s1, s2 in safe_zip(shape1, shape2))
|
||||
|
||||
|
||||
def _validate_bcsr_indices(indices: jnp.ndarray, indptr: jnp.ndarray,
|
||||
def _validate_bcsr_indices(indices: jax.Array, indptr: jax.Array,
|
||||
shape: Sequence[int]) -> BCSRProperties:
|
||||
assert jnp.issubdtype(indices.dtype, jnp.integer)
|
||||
assert jnp.issubdtype(indptr.dtype, jnp.integer)
|
||||
@ -118,8 +119,8 @@ def _validate_bcsr_indices(indices: jnp.ndarray, indptr: jnp.ndarray,
|
||||
return BCSRProperties(n_batch=n_batch, n_dense=n_dense, nse=nse)
|
||||
|
||||
|
||||
def _validate_bcsr(data: jnp.ndarray, indices: jnp.ndarray,
|
||||
indptr: jnp.ndarray, shape: Sequence[int]) -> BCSRProperties:
|
||||
def _validate_bcsr(data: jax.Array, indices: jax.Array,
|
||||
indptr: jax.Array, shape: Sequence[int]) -> BCSRProperties:
|
||||
props = _validate_bcsr_indices(indices, indptr, shape)
|
||||
shape = tuple(shape)
|
||||
n_batch, n_dense, nse = props.n_batch, props.n_dense, props.nse
|
||||
@ -134,8 +135,8 @@ def _validate_bcsr(data: jnp.ndarray, indices: jnp.ndarray,
|
||||
return props
|
||||
|
||||
|
||||
def _bcsr_to_bcoo(indices: jnp.ndarray, indptr: jnp.ndarray, *,
|
||||
shape: Sequence[int]) -> jnp.ndarray:
|
||||
def _bcsr_to_bcoo(indices: jax.Array, indptr: jax.Array, *,
|
||||
shape: Sequence[int]) -> jax.Array:
|
||||
"""Given BCSR (indices, indptr), return BCOO (indices)."""
|
||||
n_batch, _, _ = _validate_bcsr_indices(indices, indptr, shape)
|
||||
csr_to_coo = nfold_vmap(_csr_to_coo, n_batch)
|
||||
@ -478,8 +479,8 @@ def bcsr_dot_general(lhs: Union[BCSR, Array], rhs: Array, *,
|
||||
dense, the result will be dense, of type ndarray.
|
||||
"""
|
||||
del precision, preferred_element_type # unused
|
||||
if isinstance(rhs, (np.ndarray, jnp.ndarray)):
|
||||
if isinstance(lhs, (np.ndarray, jnp.ndarray)):
|
||||
if isinstance(rhs, (np.ndarray, jax.Array)):
|
||||
if isinstance(lhs, (np.ndarray, jax.Array)):
|
||||
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)
|
||||
|
||||
if isinstance(lhs, BCSR):
|
||||
@ -492,8 +493,8 @@ def bcsr_dot_general(lhs: Union[BCSR, Array], rhs: Array, *,
|
||||
"lhs and ndarray rhs.")
|
||||
|
||||
|
||||
def _bcsr_dot_general(lhs_data: jnp.ndarray, lhs_indices: jnp.ndarray,
|
||||
lhs_indptr: jnp.ndarray, rhs: Array, *,
|
||||
def _bcsr_dot_general(lhs_data: jax.Array, lhs_indices: jax.Array,
|
||||
lhs_indptr: jax.Array, rhs: Array, *,
|
||||
dimension_numbers: DotDimensionNumbers,
|
||||
lhs_spinfo: SparseInfo) -> Array:
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
@ -708,9 +709,9 @@ def bcsr_broadcast_in_dim(mat: BCSR, *, shape: Shape, broadcast_dimensions: Sequ
|
||||
class BCSR(JAXSparse):
|
||||
"""Experimental batched CSR matrix implemented in JAX."""
|
||||
|
||||
data: jnp.ndarray
|
||||
indices: jnp.ndarray
|
||||
indptr: jnp.ndarray
|
||||
data: jax.Array
|
||||
indices: jax.Array
|
||||
indptr: jax.Array
|
||||
shape: Shape
|
||||
nse = property(lambda self: self.indices.shape[-1])
|
||||
dtype = property(lambda self: self.data.dtype)
|
||||
|
@ -22,6 +22,7 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.interpreters import mlir
|
||||
from jax.experimental.sparse._base import JAXSparse
|
||||
@ -54,9 +55,9 @@ class COO(JAXSparse):
|
||||
grad and autodiff, and offers very little functionality. In general you
|
||||
should prefer :class:`jax.experimental.sparse.BCOO`.
|
||||
"""
|
||||
data: jnp.ndarray
|
||||
row: jnp.ndarray
|
||||
col: jnp.ndarray
|
||||
data: jax.Array
|
||||
row: jax.Array
|
||||
col: jax.Array
|
||||
shape: Tuple[int, int]
|
||||
nse = property(lambda self: self.data.size)
|
||||
dtype = property(lambda self: self.data.dtype)
|
||||
|
@ -21,6 +21,7 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.interpreters import mlir
|
||||
from jax.experimental.sparse._base import JAXSparse
|
||||
from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo
|
||||
@ -43,9 +44,9 @@ class CSR(JAXSparse):
|
||||
grad and autodiff, and offers very little functionality. In general you
|
||||
should prefer :class:`jax.experimental.sparse.BCOO`.
|
||||
"""
|
||||
data: jnp.ndarray
|
||||
indices: jnp.ndarray
|
||||
indptr: jnp.ndarray
|
||||
data: jax.Array
|
||||
indices: jax.Array
|
||||
indptr: jax.Array
|
||||
shape: Tuple[int, int]
|
||||
nse = property(lambda self: self.data.size)
|
||||
dtype = property(lambda self: self.data.dtype)
|
||||
@ -129,9 +130,9 @@ class CSR(JAXSparse):
|
||||
@tree_util.register_pytree_node_class
|
||||
class CSC(JAXSparse):
|
||||
"""Experimental CSC matrix implemented in JAX; API subject to change."""
|
||||
data: jnp.ndarray
|
||||
indices: jnp.ndarray
|
||||
indptr: jnp.ndarray
|
||||
data: jax.Array
|
||||
indices: jax.Array
|
||||
indptr: jax.Array
|
||||
shape: Tuple[int, int]
|
||||
nse = property(lambda self: self.data.size)
|
||||
dtype = property(lambda self: self.data.dtype)
|
||||
|
@ -31,10 +31,10 @@ from scipy.sparse import csr_matrix, linalg
|
||||
|
||||
|
||||
def lobpcg_standard(
|
||||
A: Union[jnp.ndarray, Callable[[jnp.ndarray], jnp.ndarray]],
|
||||
X: jnp.ndarray,
|
||||
A: Union[jax.Array, Callable[[jax.Array], jax.Array]],
|
||||
X: jax.Array,
|
||||
m: int = 100,
|
||||
tol: Union[jnp.ndarray, float, None] = None):
|
||||
tol: Union[jax.Array, float, None] = None):
|
||||
"""Compute the top-k standard eigenvalues using the LOBPCG routine.
|
||||
|
||||
LOBPCG [1] stands for Locally Optimal Block Preconditioned Conjugate Gradient.
|
||||
@ -95,16 +95,16 @@ def lobpcg_standard(
|
||||
large (only `k * 5 < n` supported), or `k == 0`.
|
||||
"""
|
||||
# Jit-compile once per matrix shape if possible.
|
||||
if isinstance(A, (jnp.ndarray, np.ndarray)):
|
||||
if isinstance(A, (jax.Array, np.ndarray)):
|
||||
return _lobpcg_standard_matrix(A, X, m, tol, debug=False)
|
||||
return _lobpcg_standard_callable(A, X, m, tol, debug=False)
|
||||
|
||||
@functools.partial(jax.jit, static_argnames=['m', 'debug'])
|
||||
def _lobpcg_standard_matrix(
|
||||
A: jnp.ndarray,
|
||||
X: jnp.ndarray,
|
||||
A: jax.Array,
|
||||
X: jax.Array,
|
||||
m: int,
|
||||
tol: Union[jnp.ndarray, float, None],
|
||||
tol: Union[jax.Array, float, None],
|
||||
debug: bool = False):
|
||||
"""Computes lobpcg_standard(), possibly with debug diagnostics."""
|
||||
return _lobpcg_standard_callable(
|
||||
@ -112,10 +112,10 @@ def _lobpcg_standard_matrix(
|
||||
|
||||
@functools.partial(jax.jit, static_argnames=['A', 'm', 'debug'])
|
||||
def _lobpcg_standard_callable(
|
||||
A: Callable[[jnp.ndarray], jnp.ndarray],
|
||||
X: jnp.ndarray,
|
||||
A: Callable[[jax.Array], jax.Array],
|
||||
X: jax.Array,
|
||||
m: int,
|
||||
tol: Union[jnp.ndarray, float, None],
|
||||
tol: Union[jax.Array, float, None],
|
||||
debug: bool = False):
|
||||
"""Supports generic lobpcg_standard() callable interface."""
|
||||
|
||||
|
@ -1474,11 +1474,11 @@ class APITest(jtu.JaxTestCase):
|
||||
"Transpose rule (for reverse-mode differentiation) for 'foo' not implemented")
|
||||
|
||||
def test_is_subclass(self):
|
||||
self.assertTrue(issubclass(device_array.DeviceArray, jnp.ndarray))
|
||||
self.assertTrue(issubclass(device_array.Buffer, jnp.ndarray))
|
||||
self.assertTrue(issubclass(pxla.ShardedDeviceArray, jnp.ndarray))
|
||||
self.assertTrue(issubclass(pxla._ShardedDeviceArray, jnp.ndarray))
|
||||
self.assertFalse(issubclass(np.ndarray, jnp.ndarray))
|
||||
self.assertTrue(issubclass(device_array.DeviceArray, jax.Array))
|
||||
self.assertTrue(issubclass(device_array.Buffer, jax.Array))
|
||||
self.assertTrue(issubclass(pxla.ShardedDeviceArray, jax.Array))
|
||||
self.assertTrue(issubclass(pxla._ShardedDeviceArray, jax.Array))
|
||||
self.assertFalse(issubclass(np.ndarray, jax.Array))
|
||||
self.assertFalse(issubclass(device_array.DeviceArray, np.ndarray))
|
||||
self.assertFalse(issubclass(device_array.Buffer, np.ndarray))
|
||||
self.assertFalse(issubclass(pxla.ShardedDeviceArray, np.ndarray))
|
||||
@ -1486,7 +1486,7 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
def test_is_instance(self):
|
||||
def f(x):
|
||||
self.assertIsInstance(x, jnp.ndarray)
|
||||
self.assertIsInstance(x, jax.Array)
|
||||
self.assertNotIsInstance(x, np.ndarray)
|
||||
return x + 2
|
||||
jit(f)(3)
|
||||
@ -1496,10 +1496,10 @@ class APITest(jtu.JaxTestCase):
|
||||
x = np.arange(12.).reshape((3, 4)).astype("float32")
|
||||
dx = api.device_put(x)
|
||||
_check_instance(self, dx)
|
||||
self.assertIsInstance(dx, jnp.ndarray)
|
||||
self.assertIsInstance(dx, jax.Array)
|
||||
self.assertNotIsInstance(dx, np.ndarray)
|
||||
x2 = api.device_get(dx)
|
||||
self.assertNotIsInstance(x2, jnp.ndarray)
|
||||
self.assertNotIsInstance(x2, jax.Array)
|
||||
self.assertIsInstance(x2, np.ndarray)
|
||||
assert np.all(x == x2)
|
||||
|
||||
@ -7114,7 +7114,7 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
raise unittest.SkipTest("test only applies when x64 is disabled")
|
||||
|
||||
@jax.custom_jvp
|
||||
def projection_unit_simplex(x: jnp.ndarray) -> jnp.ndarray:
|
||||
def projection_unit_simplex(x: jax.Array) -> jax.Array:
|
||||
"""Projection onto the unit simplex."""
|
||||
s = 1.0
|
||||
n_features = x.shape[0]
|
||||
|
@ -762,7 +762,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
|
||||
def test_is_subclass(self):
|
||||
# array version of api_test.py::APITest::test_is_subclass
|
||||
self.assertTrue(issubclass(array.ArrayImpl, jnp.ndarray))
|
||||
self.assertTrue(issubclass(array.ArrayImpl, jax.Array))
|
||||
self.assertFalse(issubclass(array.ArrayImpl, np.ndarray))
|
||||
|
||||
def test_op_sharding_sharding_repr(self):
|
||||
|
@ -98,7 +98,7 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
expected_dtype = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[type_])
|
||||
for f in [jnp.array, jax.jit(jnp.array), jax.jit(lambda x: x)]:
|
||||
y = f(type_(0))
|
||||
self.assertTrue(isinstance(y, jnp.ndarray), msg=(f, y))
|
||||
self.assertTrue(isinstance(y, jax.Array), msg=(f, y))
|
||||
self.assertEqual(y.dtype, expected_dtype, msg=(f, y))
|
||||
|
||||
def testUnsupportedType(self):
|
||||
@ -141,7 +141,7 @@ class DtypesTest(jtu.JaxTestCase):
|
||||
for x, y, dtype in testcases:
|
||||
x, y = (y, x) if swap else (x, y)
|
||||
z = op(x, y)
|
||||
self.assertTrue(isinstance(z, jnp.ndarray), msg=(x, y, z))
|
||||
self.assertTrue(isinstance(z, jax.Array), msg=(x, y, z))
|
||||
self.assertEqual(z.dtype, dtypes.canonicalize_dtype(dtype), msg=(x, y, z))
|
||||
|
||||
@jax.numpy_dtype_promotion('strict')
|
||||
|
@ -160,7 +160,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
if dtype == dtypes.canonicalize_dtype(dtype)])
|
||||
def testDtypeWrappers(self, dtype):
|
||||
arr = dtype(0)
|
||||
self.assertIsInstance(arr, jnp.ndarray)
|
||||
self.assertIsInstance(arr, jax.Array)
|
||||
self.assertEqual(arr.dtype, np.dtype(dtype))
|
||||
self.assertArraysEqual(arr, 0, check_dtypes=False)
|
||||
|
||||
@ -3183,7 +3183,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
self.assertIsInstance(x, jnp.ndarray)
|
||||
self.assertIsInstance(x, jax.Array)
|
||||
return jnp.sum(x)
|
||||
|
||||
f(arr)
|
||||
@ -4833,7 +4833,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
check_dtypes=False)
|
||||
|
||||
def testBroadcastToOnScalar(self):
|
||||
self.assertIsInstance(jnp.broadcast_to(10.0, ()), jnp.ndarray)
|
||||
self.assertIsInstance(jnp.broadcast_to(10.0, ()), jax.Array)
|
||||
self.assertIsInstance(np.broadcast_to(10.0, ()), np.ndarray)
|
||||
|
||||
def testPrecision(self):
|
||||
|
@ -2951,7 +2951,7 @@ def take_abstract_eval(x):
|
||||
|
||||
class FooArray:
|
||||
shape: Tuple[int, ...]
|
||||
data: jnp.ndarray
|
||||
data: jax.Array
|
||||
|
||||
def __init__(self, shape, data):
|
||||
assert data.shape == (*shape, 2)
|
||||
|
@ -274,7 +274,7 @@ class LobpcgTest(jtu.JaxTestCase):
|
||||
if not os.getenv('LOBPCG_EMIT_DEBUG_PLOTS'):
|
||||
return
|
||||
|
||||
if isinstance(A, (np.ndarray, jnp.ndarray)):
|
||||
if isinstance(A, (np.ndarray, jax.Array)):
|
||||
lobpcg = linalg._lobpcg_standard_matrix
|
||||
else:
|
||||
lobpcg = linalg._lobpcg_standard_callable
|
||||
|
@ -775,7 +775,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
|
||||
# test that we can pass in and out ShardedDeviceArrays
|
||||
y = f(x)
|
||||
self.assertIsInstance(y, jnp.ndarray)
|
||||
self.assertIsInstance(y, jax.Array)
|
||||
if config.jax_array:
|
||||
self.assertIsInstance(y, array.ArrayImpl)
|
||||
else:
|
||||
@ -1655,7 +1655,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
|
||||
y = f(x)
|
||||
self.assertIsInstance(y, jnp.ndarray)
|
||||
self.assertIsInstance(y, jax.Array)
|
||||
if config.jax_array:
|
||||
self.assertIsInstance(y, array.ArrayImpl)
|
||||
else:
|
||||
|
@ -39,7 +39,7 @@ _QDWH_TEST_EPS = jnp.finfo(_QDWH_TEST_DTYPE).eps
|
||||
_MAX_LOG_CONDITION_NUM = np.log10(int(1 / _QDWH_TEST_EPS))
|
||||
|
||||
|
||||
def _check_symmetry(x: jnp.ndarray) -> bool:
|
||||
def _check_symmetry(x: jax.Array) -> bool:
|
||||
"""Check if the array is symmetric."""
|
||||
m, n = x.shape
|
||||
eps = jnp.finfo(x.dtype).eps
|
||||
|
@ -506,7 +506,7 @@ class PrngTest(jtu.JaxTestCase):
|
||||
if not config.jax_enable_custom_prng:
|
||||
self.skipTest("test requires config.jax_enable_custom_prng")
|
||||
key = random.PRNGKey(0)
|
||||
self.assertIsInstance(key, jnp.ndarray)
|
||||
self.assertIsInstance(key, jax.Array)
|
||||
|
||||
|
||||
class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
@ -2381,7 +2381,7 @@ class SparseObjectTest(sptu.SparseTestCase):
|
||||
|
||||
with self.subTest('to_elt'):
|
||||
M_out = vmap(to_elt)(Msp)
|
||||
self.assertIsInstance(M_out, jnp.ndarray)
|
||||
self.assertIsInstance(M_out, jax.Array)
|
||||
self.assertEqual(Msp.shape, M_out.shape)
|
||||
|
||||
with self.subTest('axis_None'):
|
||||
|
Loading…
x
Reference in New Issue
Block a user