diff --git a/jax/__init__.py b/jax/__init__.py index 110725e90..4b9fd3b45 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -126,6 +126,7 @@ from . import errors as errors from . import image as image from . import lax as lax from . import nn as nn +from . import ops as ops from . import profiler as profiler from . import random as random from . import tree_util as tree_util diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 3a0b29262..2c57f0a0a 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -22,7 +22,6 @@ from jax._src.numpy.vectorize import vectorize from jax._src import ad_util from jax._src import api from jax import lax -from jax import ops from jax._src import dtypes from jax.interpreters import xla from jax.interpreters import ad @@ -751,8 +750,8 @@ def _lu_pivots_body_fn(i, permutation_and_swaps): iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims)) x = permutation[..., i] y = permutation[iotas + (j,)] - permutation = ops.index_update(permutation, ops.index[..., i], y) - return ops.index_update(permutation, ops.index[iotas + (j,)], x), swaps + permutation = permutation.at[..., i].set(y) + return permutation.at[iotas + (j,)].set(x), swaps @partial(api.jit, static_argnums=(1,)) @@ -853,16 +852,13 @@ def _lu_unblocked(a): else: magnitude = jnp.abs(a[:, k]) i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf)) - pivot = ops.index_update(pivot, ops.index[k], i) - - a = ops.index_update(a, ops.index[[k, i],], a[[i, k],]) - - perm = ops.index_update(perm, ops.index[[i, k],], perm[[k, i],]) + pivot = pivot.at[k].set(i) + a = a.at[[k, i],].set(a[[i, k],]) + perm = perm.at[[i, k],].set(perm[[k, i],]) # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes x = a[k, k] - a = ops.index_update(a, ops.index[:, k], - jnp.where(m_idx > k, a[:, k] / x, a[:, k])) + a = a.at[:, k].set(jnp.where(m_idx > k, a[:, k] / x, a[:, k])) # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:]) a = a - jnp.where((m_idx[:, None] > k) & (n_idx > k), @@ -888,20 +884,17 @@ def _lu_blocked(a, block_size=128): b = min(r - k, block_size) block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k+b]) - pivot = ops.index_update(pivot, ops.index[k:k+b], block_pivot + k) - perm = ops.index_update(perm, ops.index[k:], perm[block_perm + k]) - a = ops.index_update(a, ops.index[k:, :], a[block_perm + k, :]) - a = ops.index_update(a, ops.index[k:, k:k+b], lu_block) + pivot = pivot.at[k:k+b].set(block_pivot + k) + perm = perm.at[k:].set(perm[block_perm + k]) + a = a.at[k:, :].set(a[block_perm + k, :]) + a = a.at[k:, k:k+b].set(lu_block) if k + b < n: - a = ops.index_update( - a, ops.index[k:k+b, k+b:], - triangular_solve(a[k:k+b, k:k+b], a[k:k+b, k+b:], - left_side=True, lower=True, unit_diagonal=True)) - a = ops.index_add( - a, ops.index[k+b:, k+b:], - -lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:], - precision=lax.Precision.HIGHEST)) + a = a.at[k:k+b, k+b:].set( + triangular_solve(a[k:k+b, k:k+b], a[k:k+b, k+b:], left_side=True, + lower=True, unit_diagonal=True)) + a = a.at[k+b:, k+b:].add(-lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:], + precision=lax.Precision.HIGHEST)) return a, pivot, perm def _lu_python(x): diff --git a/jax/_src/lax/polar.py b/jax/_src/lax/polar.py index 6e3ca9b63..f2033f055 100644 --- a/jax/_src/lax/polar.py +++ b/jax/_src/lax/polar.py @@ -32,7 +32,7 @@ import jax.scipy as jsp def _add_to_diagonal(X, val): new_diagonal = X.diagonal() + val diag_indices = jnp.diag_indices(X.shape[0]) - return jax.ops.index_update(X, diag_indices, new_diagonal) + return X.at[diag_indices].set(new_diagonal) @jax.jit diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index a816cbabd..f55fb5aca 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -24,7 +24,6 @@ import numpy as np import jax.numpy as jnp from jax import lax -from jax import ops from jax import random from jax import core from jax._src.util import prod @@ -186,12 +185,11 @@ def delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float_): W = jnp.zeros(shape, dtype=dtype) if len(shape) == 3: k = shape[0] - return ops.index_update(W, ops.index[(k-1)//2, ...], ortho_matrix) + return W.at[(k-1)//2, ...].set(ortho_matrix) elif len(shape) == 4: k1, k2 = shape[:2] - return ops.index_update(W, ops.index[(k1-1)//2, (k2-1)//2, ...], ortho_matrix) + return W.at[(k1-1)//2, (k2-1)//2, ...].set(ortho_matrix) else: k1, k2, k3 = shape[:3] - return ops.index_update(W, ops.index[(k1-1)//2, (k2-1)//2, (k3-1)//2, ...], - ortho_matrix) + return W.at[(k1-1)//2, (k2-1)//2, (k3-1)//2, ...].set(ortho_matrix) return init diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index b7ab91ee2..0f3fbe29f 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -21,7 +21,6 @@ from jax.lib import xla_client from jax._src.util import safe_zip from .util import _wraps from . import lax_numpy as jnp -from jax import ops as jaxops def _fft_core(func_name, fft_type, a, s, axes, norm): @@ -203,19 +202,17 @@ def fftfreq(n, d=1.0): k = jnp.zeros(n) if n % 2 == 0: # k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1) - k = jaxops.index_update(k, jaxops.index[0: n // 2], jnp.arange(0, n // 2)) + k = k.at[0: n // 2].set( jnp.arange(0, n // 2)) # k[n // 2:] = jnp.arange(-n // 2, -1) - k = jaxops.index_update(k, jaxops.index[n // 2:], jnp.arange(-n // 2, 0)) + k = k.at[n // 2:].set( jnp.arange(-n // 2, 0)) else: # k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2) - k = jaxops.index_update(k, jaxops.index[0: (n - 1) // 2 + 1], - jnp.arange(0, (n - 1) // 2 + 1)) + k = k.at[0: (n - 1) // 2 + 1].set(jnp.arange(0, (n - 1) // 2 + 1)) # k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1) - k = jaxops.index_update(k, jaxops.index[(n - 1) // 2 + 1:], - jnp.arange(-(n - 1) // 2, 0)) + k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0)) return k / (d * n) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 2af6c54db..481b28ffe 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -50,7 +50,6 @@ from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray from jax.interpreters import pxla from jax import lax from jax._src.lax.lax import _device_put_raw -from jax import ops from jax._src.ops import scatter from jax._src.util import (partial, unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio, canonicalize_axis as _canonicalize_axis, maybe_named_axis) @@ -4035,10 +4034,8 @@ def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None): # Cumsum to get indices of new number in repeated tensor, e.g. [0, 1, 3, 3] scatter_indices = cumsum(exclusive_repeats) # Scatter these onto a zero buffer, e.g. [1,1,0,2,0,0,0,0] - block_split_indicators = ops.index_add( - x=zeros([total_repeat_length], dtype=int32), - idx=scatter_indices, - y=1) + block_split_indicators = zeros([total_repeat_length], dtype=int32) + block_split_indicators = block_split_indicators.at[scatter_indices].add(1) # Cumsum again to get scatter indices for repeat, e.g. [0,1,1,3,3,3,3,3] gather_indices = cumsum(block_split_indicators) - 1 return take(a, gather_indices, axis=axis) @@ -4217,7 +4214,7 @@ def diagflat(v, k=0): fi = i+k+i*adj_length else: fi = i+(i-k)*adj_length - res = ops.index_update(res, ops.index[fi], v) + res = res.at[fi].set(v) res = res.reshape(adj_length, adj_length) return res diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index fb83613bc..8867aa1ee 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -22,7 +22,6 @@ from typing import Tuple, Union, cast from jax import jit, custom_jvp from jax import lax -from jax import ops from jax._src.lax import linalg as lax_linalg from jax._src import dtypes from .util import _wraps @@ -218,7 +217,7 @@ def _cofactor_solve(a, b): # partial_det[:, -1] contains the full determinant and # partial_det[:, -2] contains det(u) / u_{nn}. partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None] - lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2]) + lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2]) permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],)) iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,))) # filter out any matrices that are not full rank diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 5a3dc10ca..58c6a0d1a 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -20,7 +20,6 @@ from . import lax_numpy as jnp from jax import jit from .util import _wraps from .linalg import eigvals as _eigvals -from jax import ops as jaxops def _to_inexact_type(type): @@ -38,7 +37,7 @@ def _roots_no_zeros(p): # build companion matrix and find its eigenvalues (the roots) A = jnp.diag(jnp.ones((p.size - 2,), p.dtype), -1) - A = jaxops.index_update(A, jaxops.index[0, :], -p[1:] / p[0]) + A = A.at[0, :].set(-p[1:] / p[0]) roots = _eigvals(A) return roots diff --git a/jax/_src/scipy/eigh.py b/jax/_src/scipy/eigh.py index 6de93c133..febbd2243 100644 --- a/jax/_src/scipy/eigh.py +++ b/jax/_src/scipy/eigh.py @@ -138,7 +138,7 @@ def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST): Vp: An isometry from the input space of `V0` to `Hp`. """ def _fill_diagonal(X, vals): - return jax.ops.index_update(X, jnp.diag_indices(X.shape[0]), vals) + return X.at[jnp.diag_indices(X.shape[0])].set(vals) H_shift = _fill_diagonal(H, H.diagonal() - split_point) U, _ = jsp.linalg.polar_unitary(H_shift) diff --git a/jax/_src/scipy/optimize/_lbfgs.py b/jax/_src/scipy/optimize/_lbfgs.py index d066f02ce..a46cd37a8 100644 --- a/jax/_src/scipy/optimize/_lbfgs.py +++ b/jax/_src/scipy/optimize/_lbfgs.py @@ -17,7 +17,7 @@ from functools import partial import jax import jax.numpy as jnp -from jax import lax, ops +from jax import lax from .line_search import line_search _dot = partial(jnp.dot, precision=lax.Precision.HIGHEST) @@ -206,7 +206,7 @@ def _two_loop_recursion(state: LBFGSResults): i = his_size - 1 - j _q, _a_his = carry a_i = state.rho_history[i] * jnp.real(_dot(jnp.conj(state.s_history[i]), _q)) - _a_his = ops.index_update(_a_his, ops.index[i], a_i) + _a_his = _a_his.at[i].set(a_i) _q = _q - a_i * jnp.conj(state.y_history[i]) return _q, _a_his @@ -225,9 +225,9 @@ def _two_loop_recursion(state: LBFGSResults): def _update_history_vectors(history, new): # TODO(Jakob-Unfried) use rolling buffer instead? See #6053 - return ops.index_update(jnp.roll(history, -1, axis=0), ops.index[-1, :], new) + return jnp.roll(history, -1, axis=0).at[-1, :].set(new) def _update_history_scalars(history, new): # TODO(Jakob-Unfried) use rolling buffer instead? See #6053 - return ops.index_update(jnp.roll(history, -1, axis=0), ops.index[-1], new) + return jnp.roll(history, -1, axis=0).at[-1].set(new) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index c9e3f5909..4840c8cb0 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -20,7 +20,6 @@ import scipy.special as osp_special from jax._src import api from jax import jit from jax import lax, core -from jax import ops from jax.interpreters import ad from jax._src.numpy import lax_numpy as jnp from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like, @@ -765,14 +764,14 @@ def _gen_derivatives(p: jnp.ndarray, p_p1 = p[1, 1:num_l - 1, :] coeff = -1.0 / ((l_vec + 1) * l_vec) update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1) - p_mm2_lm1 = p_mm2_lm1.at[ops.index[1, 2:num_l, :]].set(update_p_p1) + p_mm2_lm1 = p_mm2_lm1.at[1, 2:num_l, :].set(update_p_p1) if num_l > 2: l_vec = jnp.arange(2, num_l - 1) p_p2 = p[2, 2:num_l - 1, :] coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec) update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2) - p_mm2_lm1 = p_mm2_lm1.at[ops.index[0, 3:num_l, :]].set(update_p_p2) + p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2) m_mat, l_mat = jnp.mgrid[:num_m, :num_l] diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index bf2d32e53..a30da083b 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -279,17 +279,13 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): @parameterized.named_parameters( jtu.cases_from_list( - dict(testcase_name=f"_{op.__name__}", op=op) for op in ( - jax.ops.index_add, - jax.ops.index_max, - jax.ops.index_min, - jax.ops.index_mul, - jax.ops.index_update, + dict(testcase_name=f"_{op}", op=op) for op in ( + "add", "max", "min", "multiply", "set" ))) def test_scatter_static(self, op): values = np.ones((5, 6), dtype=np.float32) update = np.float32(6.) - f_jax = jax.jit(lambda v, u: op(v, jax.ops.index[::2, 3:], u)) + f_jax = jax.jit(lambda v, u: getattr(v.at[::2, 3:], op)(u)) self.ConvertAndCompare(f_jax, values, update) @parameterized.named_parameters( diff --git a/jax/experimental/loops.py b/jax/experimental/loops.py index 08e1a0d0b..4a05082ab 100644 --- a/jax/experimental/loops.py +++ b/jax/experimental/loops.py @@ -34,10 +34,10 @@ returns an updated array, e.g.:: arr = np.zeros(5) def loop_body(i, acc_arr): - arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.) + arr1 = acc_arr.at[i].set(acc_arr[i] + 2.) return lax.cond(i % 2 == 0, arr1, - lambda arr1: ops.index_update(arr1, i, arr1[i] + 1), + lambda arr1: arr1.at[i].set(arr1[i] + 1), arr1, lambda arr1: arr1) arr = lax.fori_loop(0, arr.shape[0], loop_body, arr) @@ -52,9 +52,9 @@ special `loops.scope` object and use `for` loops over special with loops.Scope() as s: s.arr = np.zeros(5) # Create the mutable state of the loop as `scope` fields. for i in s.range(s.arr.shape[0]): - s.arr = ops.index_update(s.arr, i, s.arr[i] + 2.) + s.arr = s.arr.at[i].set(s.arr[i] + 2.) for _ in s.cond_range(i % 2 == 0): # Conditionals as loops with 0 or 1 iterations - s.arr = ops.index_update(s.arr, i, s.arr[i] + 1.) + s.arr = s.arr.at[i].set(s.arr[i] + 1.) Loops constructed with `range` must have literal constant bounds. If you need loops with dynamic bounds, you can use the more general `while_range` iterator. diff --git a/tests/batching_test.py b/tests/batching_test.py index b19dc6688..aea61426d 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -28,7 +28,6 @@ from jax import random from jax import jit, grad, jvp, vjp, make_jaxpr, jacfwd, jacrev, hessian from jax import vmap from jax._src.util import partial -import jax.ops from jax.config import config config.parse_flags_with_absl() @@ -936,7 +935,7 @@ class BatchingTest(jtu.JaxTestCase): self.assertEqual((), empty_tuple) def testIndexAddBatchedIndexesOnly(self): - f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y) + f = lambda x, idx, y: jnp.asarray(x).at[idx].add(y) result = vmap(f, (None, 0, None))(np.zeros((10,)), np.arange(10,), 1.) self.assertAllClose(result, np.eye(10), check_dtypes=False) diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index ccb2a3e21..c5b779f4f 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -948,7 +948,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): def f(x): n = x.shape[0] y = np.arange(n, dtype=x.dtype) - return jax.ops.index_update(x, np.diag_indices(n), y) + return jax.device_put(x).at[np.diag_indices(n)].set(y) rng = jtu.rand_default(self.rng()) check_grads(f, (rng((5, 5), np.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 4dfb6bc63..792314965 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1711,10 +1711,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def jnp_pad_with(vector, pad_width, iaxis, kwargs): pad_value = kwargs.get('padder', 10) - vector = jax.ops.index_update( - vector, jax.ops.index[:pad_width[0]], pad_value) - vector = jax.ops.index_update( - vector, jax.ops.index[-pad_width[1]:], pad_value) + vector = vector.at[:pad_width[0]].set(pad_value) + vector = vector.at[-pad_width[1]:].set(pad_value) return vector arr = np.arange(6).reshape(2, 3) diff --git a/tests/loops_test.py b/tests/loops_test.py index 08fa2ca3b..26cae5a85 100644 --- a/tests/loops_test.py +++ b/tests/loops_test.py @@ -20,7 +20,7 @@ import numpy as np import re import jax -from jax import lax, ops +from jax import lax from jax import numpy as jnp from jax import test_util as jtu from jax.experimental import loops @@ -87,7 +87,7 @@ class LoopsTest(jtu.JaxTestCase): assert n == y.shape[0] s.out = jnp.zeros(shape=[n], dtype=jnp.float32) for i in s.range(n): - s.out = ops.index_add(s.out, i, x[i] + y[i]) + s.out = s.out.at[i].add(x[i] + y[i]) return s.out x = jnp.array([1., 2., 3.], dtype=jnp.float32) @@ -104,7 +104,7 @@ class LoopsTest(jtu.JaxTestCase): for i in s.range(n): for j in s.range(p): for k in s.range(m): - s.out = ops.index_add(s.out, (i, j), x[i, k] * y[k, j]) + s.out = s.out.at[(i, j)].add(x[i, k] * y[k, j]) return s.out x = jnp.array([[1., 2., 3.]], dtype=jnp.float32) # 1x3 @@ -178,10 +178,10 @@ class LoopsTest(jtu.JaxTestCase): def f_op_jax(): arr = jnp.zeros(5) def loop_body(i, acc_arr): - arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.) + arr1 = acc_arr.at[i].set(acc_arr[i] + 2.) return lax.cond(i % 2 == 0, arr1, - lambda arr1: ops.index_update(arr1, i, arr1[i] + 1.), + lambda arr1: arr1.at[i].set(arr1[i] + 1.), arr1, lambda arr1: arr1) arr = lax.fori_loop(0, arr.shape[0], loop_body, arr) @@ -191,9 +191,9 @@ class LoopsTest(jtu.JaxTestCase): with loops.Scope() as s: s.arr = jnp.zeros(5) # Must create the mutable state of the loop as `scope` fields. for i in s.range(s.arr.shape[0]): - s.arr = ops.index_update(s.arr, i, s.arr[i] + 2.) + s.arr = s.arr.at[i].set(s.arr[i] + 2.) for _ in s.cond_range(i % 2 == 0): # Conditionals are also sugared as loops with 0 or 1 iterations - s.arr = ops.index_update(s.arr, i, s.arr[i] + 1.) + s.arr = s.arr.at[i].set(s.arr[i] + 1.) return s.arr self.assertAllClose(f_expected(), f_op_jax())