mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Switch internal users of jax.ops.index_... to use x.at[x].set() APIs.
This commit is contained in:
parent
4d68a79921
commit
a84426cb8f
@ -126,6 +126,7 @@ from . import errors as errors
|
||||
from . import image as image
|
||||
from . import lax as lax
|
||||
from . import nn as nn
|
||||
from . import ops as ops
|
||||
from . import profiler as profiler
|
||||
from . import random as random
|
||||
from . import tree_util as tree_util
|
||||
|
@ -22,7 +22,6 @@ from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src import ad_util
|
||||
from jax._src import api
|
||||
from jax import lax
|
||||
from jax import ops
|
||||
from jax._src import dtypes
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import ad
|
||||
@ -751,8 +750,8 @@ def _lu_pivots_body_fn(i, permutation_and_swaps):
|
||||
iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims))
|
||||
x = permutation[..., i]
|
||||
y = permutation[iotas + (j,)]
|
||||
permutation = ops.index_update(permutation, ops.index[..., i], y)
|
||||
return ops.index_update(permutation, ops.index[iotas + (j,)], x), swaps
|
||||
permutation = permutation.at[..., i].set(y)
|
||||
return permutation.at[iotas + (j,)].set(x), swaps
|
||||
|
||||
|
||||
@partial(api.jit, static_argnums=(1,))
|
||||
@ -853,16 +852,13 @@ def _lu_unblocked(a):
|
||||
else:
|
||||
magnitude = jnp.abs(a[:, k])
|
||||
i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf))
|
||||
pivot = ops.index_update(pivot, ops.index[k], i)
|
||||
|
||||
a = ops.index_update(a, ops.index[[k, i],], a[[i, k],])
|
||||
|
||||
perm = ops.index_update(perm, ops.index[[i, k],], perm[[k, i],])
|
||||
pivot = pivot.at[k].set(i)
|
||||
a = a.at[[k, i],].set(a[[i, k],])
|
||||
perm = perm.at[[i, k],].set(perm[[k, i],])
|
||||
|
||||
# a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
|
||||
x = a[k, k]
|
||||
a = ops.index_update(a, ops.index[:, k],
|
||||
jnp.where(m_idx > k, a[:, k] / x, a[:, k]))
|
||||
a = a.at[:, k].set(jnp.where(m_idx > k, a[:, k] / x, a[:, k]))
|
||||
|
||||
# a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
|
||||
a = a - jnp.where((m_idx[:, None] > k) & (n_idx > k),
|
||||
@ -888,20 +884,17 @@ def _lu_blocked(a, block_size=128):
|
||||
b = min(r - k, block_size)
|
||||
block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k+b])
|
||||
|
||||
pivot = ops.index_update(pivot, ops.index[k:k+b], block_pivot + k)
|
||||
perm = ops.index_update(perm, ops.index[k:], perm[block_perm + k])
|
||||
a = ops.index_update(a, ops.index[k:, :], a[block_perm + k, :])
|
||||
a = ops.index_update(a, ops.index[k:, k:k+b], lu_block)
|
||||
pivot = pivot.at[k:k+b].set(block_pivot + k)
|
||||
perm = perm.at[k:].set(perm[block_perm + k])
|
||||
a = a.at[k:, :].set(a[block_perm + k, :])
|
||||
a = a.at[k:, k:k+b].set(lu_block)
|
||||
|
||||
if k + b < n:
|
||||
a = ops.index_update(
|
||||
a, ops.index[k:k+b, k+b:],
|
||||
triangular_solve(a[k:k+b, k:k+b], a[k:k+b, k+b:],
|
||||
left_side=True, lower=True, unit_diagonal=True))
|
||||
a = ops.index_add(
|
||||
a, ops.index[k+b:, k+b:],
|
||||
-lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:],
|
||||
precision=lax.Precision.HIGHEST))
|
||||
a = a.at[k:k+b, k+b:].set(
|
||||
triangular_solve(a[k:k+b, k:k+b], a[k:k+b, k+b:], left_side=True,
|
||||
lower=True, unit_diagonal=True))
|
||||
a = a.at[k+b:, k+b:].add(-lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:],
|
||||
precision=lax.Precision.HIGHEST))
|
||||
return a, pivot, perm
|
||||
|
||||
def _lu_python(x):
|
||||
|
@ -32,7 +32,7 @@ import jax.scipy as jsp
|
||||
def _add_to_diagonal(X, val):
|
||||
new_diagonal = X.diagonal() + val
|
||||
diag_indices = jnp.diag_indices(X.shape[0])
|
||||
return jax.ops.index_update(X, diag_indices, new_diagonal)
|
||||
return X.at[diag_indices].set(new_diagonal)
|
||||
|
||||
|
||||
@jax.jit
|
||||
|
@ -24,7 +24,6 @@ import numpy as np
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax import ops
|
||||
from jax import random
|
||||
from jax import core
|
||||
from jax._src.util import prod
|
||||
@ -186,12 +185,11 @@ def delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float_):
|
||||
W = jnp.zeros(shape, dtype=dtype)
|
||||
if len(shape) == 3:
|
||||
k = shape[0]
|
||||
return ops.index_update(W, ops.index[(k-1)//2, ...], ortho_matrix)
|
||||
return W.at[(k-1)//2, ...].set(ortho_matrix)
|
||||
elif len(shape) == 4:
|
||||
k1, k2 = shape[:2]
|
||||
return ops.index_update(W, ops.index[(k1-1)//2, (k2-1)//2, ...], ortho_matrix)
|
||||
return W.at[(k1-1)//2, (k2-1)//2, ...].set(ortho_matrix)
|
||||
else:
|
||||
k1, k2, k3 = shape[:3]
|
||||
return ops.index_update(W, ops.index[(k1-1)//2, (k2-1)//2, (k3-1)//2, ...],
|
||||
ortho_matrix)
|
||||
return W.at[(k1-1)//2, (k2-1)//2, (k3-1)//2, ...].set(ortho_matrix)
|
||||
return init
|
||||
|
@ -21,7 +21,6 @@ from jax.lib import xla_client
|
||||
from jax._src.util import safe_zip
|
||||
from .util import _wraps
|
||||
from . import lax_numpy as jnp
|
||||
from jax import ops as jaxops
|
||||
|
||||
|
||||
def _fft_core(func_name, fft_type, a, s, axes, norm):
|
||||
@ -203,19 +202,17 @@ def fftfreq(n, d=1.0):
|
||||
k = jnp.zeros(n)
|
||||
if n % 2 == 0:
|
||||
# k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1)
|
||||
k = jaxops.index_update(k, jaxops.index[0: n // 2], jnp.arange(0, n // 2))
|
||||
k = k.at[0: n // 2].set( jnp.arange(0, n // 2))
|
||||
|
||||
# k[n // 2:] = jnp.arange(-n // 2, -1)
|
||||
k = jaxops.index_update(k, jaxops.index[n // 2:], jnp.arange(-n // 2, 0))
|
||||
k = k.at[n // 2:].set( jnp.arange(-n // 2, 0))
|
||||
|
||||
else:
|
||||
# k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2)
|
||||
k = jaxops.index_update(k, jaxops.index[0: (n - 1) // 2 + 1],
|
||||
jnp.arange(0, (n - 1) // 2 + 1))
|
||||
k = k.at[0: (n - 1) // 2 + 1].set(jnp.arange(0, (n - 1) // 2 + 1))
|
||||
|
||||
# k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1)
|
||||
k = jaxops.index_update(k, jaxops.index[(n - 1) // 2 + 1:],
|
||||
jnp.arange(-(n - 1) // 2, 0))
|
||||
k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0))
|
||||
|
||||
return k / (d * n)
|
||||
|
||||
|
@ -50,7 +50,6 @@ from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
|
||||
from jax.interpreters import pxla
|
||||
from jax import lax
|
||||
from jax._src.lax.lax import _device_put_raw
|
||||
from jax import ops
|
||||
from jax._src.ops import scatter
|
||||
from jax._src.util import (partial, unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
|
||||
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
|
||||
@ -4035,10 +4034,8 @@ def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
|
||||
# Cumsum to get indices of new number in repeated tensor, e.g. [0, 1, 3, 3]
|
||||
scatter_indices = cumsum(exclusive_repeats)
|
||||
# Scatter these onto a zero buffer, e.g. [1,1,0,2,0,0,0,0]
|
||||
block_split_indicators = ops.index_add(
|
||||
x=zeros([total_repeat_length], dtype=int32),
|
||||
idx=scatter_indices,
|
||||
y=1)
|
||||
block_split_indicators = zeros([total_repeat_length], dtype=int32)
|
||||
block_split_indicators = block_split_indicators.at[scatter_indices].add(1)
|
||||
# Cumsum again to get scatter indices for repeat, e.g. [0,1,1,3,3,3,3,3]
|
||||
gather_indices = cumsum(block_split_indicators) - 1
|
||||
return take(a, gather_indices, axis=axis)
|
||||
@ -4217,7 +4214,7 @@ def diagflat(v, k=0):
|
||||
fi = i+k+i*adj_length
|
||||
else:
|
||||
fi = i+(i-k)*adj_length
|
||||
res = ops.index_update(res, ops.index[fi], v)
|
||||
res = res.at[fi].set(v)
|
||||
res = res.reshape(adj_length, adj_length)
|
||||
return res
|
||||
|
||||
|
@ -22,7 +22,6 @@ from typing import Tuple, Union, cast
|
||||
|
||||
from jax import jit, custom_jvp
|
||||
from jax import lax
|
||||
from jax import ops
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src import dtypes
|
||||
from .util import _wraps
|
||||
@ -218,7 +217,7 @@ def _cofactor_solve(a, b):
|
||||
# partial_det[:, -1] contains the full determinant and
|
||||
# partial_det[:, -2] contains det(u) / u_{nn}.
|
||||
partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
|
||||
lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
|
||||
lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
|
||||
permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
|
||||
iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
|
||||
# filter out any matrices that are not full rank
|
||||
|
@ -20,7 +20,6 @@ from . import lax_numpy as jnp
|
||||
from jax import jit
|
||||
from .util import _wraps
|
||||
from .linalg import eigvals as _eigvals
|
||||
from jax import ops as jaxops
|
||||
|
||||
|
||||
def _to_inexact_type(type):
|
||||
@ -38,7 +37,7 @@ def _roots_no_zeros(p):
|
||||
|
||||
# build companion matrix and find its eigenvalues (the roots)
|
||||
A = jnp.diag(jnp.ones((p.size - 2,), p.dtype), -1)
|
||||
A = jaxops.index_update(A, jaxops.index[0, :], -p[1:] / p[0])
|
||||
A = A.at[0, :].set(-p[1:] / p[0])
|
||||
roots = _eigvals(A)
|
||||
return roots
|
||||
|
||||
|
@ -138,7 +138,7 @@ def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST):
|
||||
Vp: An isometry from the input space of `V0` to `Hp`.
|
||||
"""
|
||||
def _fill_diagonal(X, vals):
|
||||
return jax.ops.index_update(X, jnp.diag_indices(X.shape[0]), vals)
|
||||
return X.at[jnp.diag_indices(X.shape[0])].set(vals)
|
||||
|
||||
H_shift = _fill_diagonal(H, H.diagonal() - split_point)
|
||||
U, _ = jsp.linalg.polar_unitary(H_shift)
|
||||
|
@ -17,7 +17,7 @@ from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax, ops
|
||||
from jax import lax
|
||||
from .line_search import line_search
|
||||
|
||||
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
|
||||
@ -206,7 +206,7 @@ def _two_loop_recursion(state: LBFGSResults):
|
||||
i = his_size - 1 - j
|
||||
_q, _a_his = carry
|
||||
a_i = state.rho_history[i] * jnp.real(_dot(jnp.conj(state.s_history[i]), _q))
|
||||
_a_his = ops.index_update(_a_his, ops.index[i], a_i)
|
||||
_a_his = _a_his.at[i].set(a_i)
|
||||
_q = _q - a_i * jnp.conj(state.y_history[i])
|
||||
return _q, _a_his
|
||||
|
||||
@ -225,9 +225,9 @@ def _two_loop_recursion(state: LBFGSResults):
|
||||
|
||||
def _update_history_vectors(history, new):
|
||||
# TODO(Jakob-Unfried) use rolling buffer instead? See #6053
|
||||
return ops.index_update(jnp.roll(history, -1, axis=0), ops.index[-1, :], new)
|
||||
return jnp.roll(history, -1, axis=0).at[-1, :].set(new)
|
||||
|
||||
|
||||
def _update_history_scalars(history, new):
|
||||
# TODO(Jakob-Unfried) use rolling buffer instead? See #6053
|
||||
return ops.index_update(jnp.roll(history, -1, axis=0), ops.index[-1], new)
|
||||
return jnp.roll(history, -1, axis=0).at[-1].set(new)
|
||||
|
@ -20,7 +20,6 @@ import scipy.special as osp_special
|
||||
from jax._src import api
|
||||
from jax import jit
|
||||
from jax import lax, core
|
||||
from jax import ops
|
||||
from jax.interpreters import ad
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
|
||||
@ -765,14 +764,14 @@ def _gen_derivatives(p: jnp.ndarray,
|
||||
p_p1 = p[1, 1:num_l - 1, :]
|
||||
coeff = -1.0 / ((l_vec + 1) * l_vec)
|
||||
update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1)
|
||||
p_mm2_lm1 = p_mm2_lm1.at[ops.index[1, 2:num_l, :]].set(update_p_p1)
|
||||
p_mm2_lm1 = p_mm2_lm1.at[1, 2:num_l, :].set(update_p_p1)
|
||||
|
||||
if num_l > 2:
|
||||
l_vec = jnp.arange(2, num_l - 1)
|
||||
p_p2 = p[2, 2:num_l - 1, :]
|
||||
coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec)
|
||||
update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
|
||||
p_mm2_lm1 = p_mm2_lm1.at[ops.index[0, 3:num_l, :]].set(update_p_p2)
|
||||
p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2)
|
||||
|
||||
m_mat, l_mat = jnp.mgrid[:num_m, :num_l]
|
||||
|
||||
|
@ -279,17 +279,13 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
dict(testcase_name=f"_{op.__name__}", op=op) for op in (
|
||||
jax.ops.index_add,
|
||||
jax.ops.index_max,
|
||||
jax.ops.index_min,
|
||||
jax.ops.index_mul,
|
||||
jax.ops.index_update,
|
||||
dict(testcase_name=f"_{op}", op=op) for op in (
|
||||
"add", "max", "min", "multiply", "set"
|
||||
)))
|
||||
def test_scatter_static(self, op):
|
||||
values = np.ones((5, 6), dtype=np.float32)
|
||||
update = np.float32(6.)
|
||||
f_jax = jax.jit(lambda v, u: op(v, jax.ops.index[::2, 3:], u))
|
||||
f_jax = jax.jit(lambda v, u: getattr(v.at[::2, 3:], op)(u))
|
||||
self.ConvertAndCompare(f_jax, values, update)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
|
@ -34,10 +34,10 @@ returns an updated array, e.g.::
|
||||
|
||||
arr = np.zeros(5)
|
||||
def loop_body(i, acc_arr):
|
||||
arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
|
||||
arr1 = acc_arr.at[i].set(acc_arr[i] + 2.)
|
||||
return lax.cond(i % 2 == 0,
|
||||
arr1,
|
||||
lambda arr1: ops.index_update(arr1, i, arr1[i] + 1),
|
||||
lambda arr1: arr1.at[i].set(arr1[i] + 1),
|
||||
arr1,
|
||||
lambda arr1: arr1)
|
||||
arr = lax.fori_loop(0, arr.shape[0], loop_body, arr)
|
||||
@ -52,9 +52,9 @@ special `loops.scope` object and use `for` loops over special
|
||||
with loops.Scope() as s:
|
||||
s.arr = np.zeros(5) # Create the mutable state of the loop as `scope` fields.
|
||||
for i in s.range(s.arr.shape[0]):
|
||||
s.arr = ops.index_update(s.arr, i, s.arr[i] + 2.)
|
||||
s.arr = s.arr.at[i].set(s.arr[i] + 2.)
|
||||
for _ in s.cond_range(i % 2 == 0): # Conditionals as loops with 0 or 1 iterations
|
||||
s.arr = ops.index_update(s.arr, i, s.arr[i] + 1.)
|
||||
s.arr = s.arr.at[i].set(s.arr[i] + 1.)
|
||||
|
||||
Loops constructed with `range` must have literal constant bounds. If you need
|
||||
loops with dynamic bounds, you can use the more general `while_range` iterator.
|
||||
|
@ -28,7 +28,6 @@ from jax import random
|
||||
from jax import jit, grad, jvp, vjp, make_jaxpr, jacfwd, jacrev, hessian
|
||||
from jax import vmap
|
||||
from jax._src.util import partial
|
||||
import jax.ops
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -936,7 +935,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
self.assertEqual((), empty_tuple)
|
||||
|
||||
def testIndexAddBatchedIndexesOnly(self):
|
||||
f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y)
|
||||
f = lambda x, idx, y: jnp.asarray(x).at[idx].add(y)
|
||||
result = vmap(f, (None, 0, None))(np.zeros((10,)), np.arange(10,), 1.)
|
||||
self.assertAllClose(result, np.eye(10), check_dtypes=False)
|
||||
|
||||
|
@ -948,7 +948,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
n = x.shape[0]
|
||||
y = np.arange(n, dtype=x.dtype)
|
||||
return jax.ops.index_update(x, np.diag_indices(n), y)
|
||||
return jax.device_put(x).at[np.diag_indices(n)].set(y)
|
||||
rng = jtu.rand_default(self.rng())
|
||||
check_grads(f, (rng((5, 5), np.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
|
||||
1.)
|
||||
|
@ -1711,10 +1711,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
def jnp_pad_with(vector, pad_width, iaxis, kwargs):
|
||||
pad_value = kwargs.get('padder', 10)
|
||||
vector = jax.ops.index_update(
|
||||
vector, jax.ops.index[:pad_width[0]], pad_value)
|
||||
vector = jax.ops.index_update(
|
||||
vector, jax.ops.index[-pad_width[1]:], pad_value)
|
||||
vector = vector.at[:pad_width[0]].set(pad_value)
|
||||
vector = vector.at[-pad_width[1]:].set(pad_value)
|
||||
return vector
|
||||
|
||||
arr = np.arange(6).reshape(2, 3)
|
||||
|
@ -20,7 +20,7 @@ import numpy as np
|
||||
import re
|
||||
|
||||
import jax
|
||||
from jax import lax, ops
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax.experimental import loops
|
||||
@ -87,7 +87,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
assert n == y.shape[0]
|
||||
s.out = jnp.zeros(shape=[n], dtype=jnp.float32)
|
||||
for i in s.range(n):
|
||||
s.out = ops.index_add(s.out, i, x[i] + y[i])
|
||||
s.out = s.out.at[i].add(x[i] + y[i])
|
||||
return s.out
|
||||
|
||||
x = jnp.array([1., 2., 3.], dtype=jnp.float32)
|
||||
@ -104,7 +104,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
for i in s.range(n):
|
||||
for j in s.range(p):
|
||||
for k in s.range(m):
|
||||
s.out = ops.index_add(s.out, (i, j), x[i, k] * y[k, j])
|
||||
s.out = s.out.at[(i, j)].add(x[i, k] * y[k, j])
|
||||
return s.out
|
||||
|
||||
x = jnp.array([[1., 2., 3.]], dtype=jnp.float32) # 1x3
|
||||
@ -178,10 +178,10 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
def f_op_jax():
|
||||
arr = jnp.zeros(5)
|
||||
def loop_body(i, acc_arr):
|
||||
arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
|
||||
arr1 = acc_arr.at[i].set(acc_arr[i] + 2.)
|
||||
return lax.cond(i % 2 == 0,
|
||||
arr1,
|
||||
lambda arr1: ops.index_update(arr1, i, arr1[i] + 1.),
|
||||
lambda arr1: arr1.at[i].set(arr1[i] + 1.),
|
||||
arr1,
|
||||
lambda arr1: arr1)
|
||||
arr = lax.fori_loop(0, arr.shape[0], loop_body, arr)
|
||||
@ -191,9 +191,9 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
with loops.Scope() as s:
|
||||
s.arr = jnp.zeros(5) # Must create the mutable state of the loop as `scope` fields.
|
||||
for i in s.range(s.arr.shape[0]):
|
||||
s.arr = ops.index_update(s.arr, i, s.arr[i] + 2.)
|
||||
s.arr = s.arr.at[i].set(s.arr[i] + 2.)
|
||||
for _ in s.cond_range(i % 2 == 0): # Conditionals are also sugared as loops with 0 or 1 iterations
|
||||
s.arr = ops.index_update(s.arr, i, s.arr[i] + 1.)
|
||||
s.arr = s.arr.at[i].set(s.arr[i] + 1.)
|
||||
return s.arr
|
||||
|
||||
self.assertAllClose(f_expected(), f_op_jax())
|
||||
|
Loading…
x
Reference in New Issue
Block a user