Switch internal users of jax.ops.index_... to use x.at[x].set() APIs.

This commit is contained in:
Peter Hawkins 2021-09-13 16:40:45 -04:00
parent 4d68a79921
commit a84426cb8f
17 changed files with 54 additions and 78 deletions

View File

@ -126,6 +126,7 @@ from . import errors as errors
from . import image as image
from . import lax as lax
from . import nn as nn
from . import ops as ops
from . import profiler as profiler
from . import random as random
from . import tree_util as tree_util

View File

@ -22,7 +22,6 @@ from jax._src.numpy.vectorize import vectorize
from jax._src import ad_util
from jax._src import api
from jax import lax
from jax import ops
from jax._src import dtypes
from jax.interpreters import xla
from jax.interpreters import ad
@ -751,8 +750,8 @@ def _lu_pivots_body_fn(i, permutation_and_swaps):
iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims))
x = permutation[..., i]
y = permutation[iotas + (j,)]
permutation = ops.index_update(permutation, ops.index[..., i], y)
return ops.index_update(permutation, ops.index[iotas + (j,)], x), swaps
permutation = permutation.at[..., i].set(y)
return permutation.at[iotas + (j,)].set(x), swaps
@partial(api.jit, static_argnums=(1,))
@ -853,16 +852,13 @@ def _lu_unblocked(a):
else:
magnitude = jnp.abs(a[:, k])
i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf))
pivot = ops.index_update(pivot, ops.index[k], i)
a = ops.index_update(a, ops.index[[k, i],], a[[i, k],])
perm = ops.index_update(perm, ops.index[[i, k],], perm[[k, i],])
pivot = pivot.at[k].set(i)
a = a.at[[k, i],].set(a[[i, k],])
perm = perm.at[[i, k],].set(perm[[k, i],])
# a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
x = a[k, k]
a = ops.index_update(a, ops.index[:, k],
jnp.where(m_idx > k, a[:, k] / x, a[:, k]))
a = a.at[:, k].set(jnp.where(m_idx > k, a[:, k] / x, a[:, k]))
# a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
a = a - jnp.where((m_idx[:, None] > k) & (n_idx > k),
@ -888,20 +884,17 @@ def _lu_blocked(a, block_size=128):
b = min(r - k, block_size)
block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k+b])
pivot = ops.index_update(pivot, ops.index[k:k+b], block_pivot + k)
perm = ops.index_update(perm, ops.index[k:], perm[block_perm + k])
a = ops.index_update(a, ops.index[k:, :], a[block_perm + k, :])
a = ops.index_update(a, ops.index[k:, k:k+b], lu_block)
pivot = pivot.at[k:k+b].set(block_pivot + k)
perm = perm.at[k:].set(perm[block_perm + k])
a = a.at[k:, :].set(a[block_perm + k, :])
a = a.at[k:, k:k+b].set(lu_block)
if k + b < n:
a = ops.index_update(
a, ops.index[k:k+b, k+b:],
triangular_solve(a[k:k+b, k:k+b], a[k:k+b, k+b:],
left_side=True, lower=True, unit_diagonal=True))
a = ops.index_add(
a, ops.index[k+b:, k+b:],
-lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:],
precision=lax.Precision.HIGHEST))
a = a.at[k:k+b, k+b:].set(
triangular_solve(a[k:k+b, k:k+b], a[k:k+b, k+b:], left_side=True,
lower=True, unit_diagonal=True))
a = a.at[k+b:, k+b:].add(-lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:],
precision=lax.Precision.HIGHEST))
return a, pivot, perm
def _lu_python(x):

View File

@ -32,7 +32,7 @@ import jax.scipy as jsp
def _add_to_diagonal(X, val):
new_diagonal = X.diagonal() + val
diag_indices = jnp.diag_indices(X.shape[0])
return jax.ops.index_update(X, diag_indices, new_diagonal)
return X.at[diag_indices].set(new_diagonal)
@jax.jit

View File

@ -24,7 +24,6 @@ import numpy as np
import jax.numpy as jnp
from jax import lax
from jax import ops
from jax import random
from jax import core
from jax._src.util import prod
@ -186,12 +185,11 @@ def delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float_):
W = jnp.zeros(shape, dtype=dtype)
if len(shape) == 3:
k = shape[0]
return ops.index_update(W, ops.index[(k-1)//2, ...], ortho_matrix)
return W.at[(k-1)//2, ...].set(ortho_matrix)
elif len(shape) == 4:
k1, k2 = shape[:2]
return ops.index_update(W, ops.index[(k1-1)//2, (k2-1)//2, ...], ortho_matrix)
return W.at[(k1-1)//2, (k2-1)//2, ...].set(ortho_matrix)
else:
k1, k2, k3 = shape[:3]
return ops.index_update(W, ops.index[(k1-1)//2, (k2-1)//2, (k3-1)//2, ...],
ortho_matrix)
return W.at[(k1-1)//2, (k2-1)//2, (k3-1)//2, ...].set(ortho_matrix)
return init

View File

@ -21,7 +21,6 @@ from jax.lib import xla_client
from jax._src.util import safe_zip
from .util import _wraps
from . import lax_numpy as jnp
from jax import ops as jaxops
def _fft_core(func_name, fft_type, a, s, axes, norm):
@ -203,19 +202,17 @@ def fftfreq(n, d=1.0):
k = jnp.zeros(n)
if n % 2 == 0:
# k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1)
k = jaxops.index_update(k, jaxops.index[0: n // 2], jnp.arange(0, n // 2))
k = k.at[0: n // 2].set( jnp.arange(0, n // 2))
# k[n // 2:] = jnp.arange(-n // 2, -1)
k = jaxops.index_update(k, jaxops.index[n // 2:], jnp.arange(-n // 2, 0))
k = k.at[n // 2:].set( jnp.arange(-n // 2, 0))
else:
# k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2)
k = jaxops.index_update(k, jaxops.index[0: (n - 1) // 2 + 1],
jnp.arange(0, (n - 1) // 2 + 1))
k = k.at[0: (n - 1) // 2 + 1].set(jnp.arange(0, (n - 1) // 2 + 1))
# k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1)
k = jaxops.index_update(k, jaxops.index[(n - 1) // 2 + 1:],
jnp.arange(-(n - 1) // 2, 0))
k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0))
return k / (d * n)

View File

@ -50,7 +50,6 @@ from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
from jax.interpreters import pxla
from jax import lax
from jax._src.lax.lax import _device_put_raw
from jax import ops
from jax._src.ops import scatter
from jax._src.util import (partial, unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
@ -4035,10 +4034,8 @@ def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
# Cumsum to get indices of new number in repeated tensor, e.g. [0, 1, 3, 3]
scatter_indices = cumsum(exclusive_repeats)
# Scatter these onto a zero buffer, e.g. [1,1,0,2,0,0,0,0]
block_split_indicators = ops.index_add(
x=zeros([total_repeat_length], dtype=int32),
idx=scatter_indices,
y=1)
block_split_indicators = zeros([total_repeat_length], dtype=int32)
block_split_indicators = block_split_indicators.at[scatter_indices].add(1)
# Cumsum again to get scatter indices for repeat, e.g. [0,1,1,3,3,3,3,3]
gather_indices = cumsum(block_split_indicators) - 1
return take(a, gather_indices, axis=axis)
@ -4217,7 +4214,7 @@ def diagflat(v, k=0):
fi = i+k+i*adj_length
else:
fi = i+(i-k)*adj_length
res = ops.index_update(res, ops.index[fi], v)
res = res.at[fi].set(v)
res = res.reshape(adj_length, adj_length)
return res

View File

@ -22,7 +22,6 @@ from typing import Tuple, Union, cast
from jax import jit, custom_jvp
from jax import lax
from jax import ops
from jax._src.lax import linalg as lax_linalg
from jax._src import dtypes
from .util import _wraps
@ -218,7 +217,7 @@ def _cofactor_solve(a, b):
# partial_det[:, -1] contains the full determinant and
# partial_det[:, -2] contains det(u) / u_{nn}.
partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
# filter out any matrices that are not full rank

View File

@ -20,7 +20,6 @@ from . import lax_numpy as jnp
from jax import jit
from .util import _wraps
from .linalg import eigvals as _eigvals
from jax import ops as jaxops
def _to_inexact_type(type):
@ -38,7 +37,7 @@ def _roots_no_zeros(p):
# build companion matrix and find its eigenvalues (the roots)
A = jnp.diag(jnp.ones((p.size - 2,), p.dtype), -1)
A = jaxops.index_update(A, jaxops.index[0, :], -p[1:] / p[0])
A = A.at[0, :].set(-p[1:] / p[0])
roots = _eigvals(A)
return roots

View File

@ -138,7 +138,7 @@ def split_spectrum(H, split_point, V0=None, precision=lax.Precision.HIGHEST):
Vp: An isometry from the input space of `V0` to `Hp`.
"""
def _fill_diagonal(X, vals):
return jax.ops.index_update(X, jnp.diag_indices(X.shape[0]), vals)
return X.at[jnp.diag_indices(X.shape[0])].set(vals)
H_shift = _fill_diagonal(H, H.diagonal() - split_point)
U, _ = jsp.linalg.polar_unitary(H_shift)

View File

@ -17,7 +17,7 @@ from functools import partial
import jax
import jax.numpy as jnp
from jax import lax, ops
from jax import lax
from .line_search import line_search
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
@ -206,7 +206,7 @@ def _two_loop_recursion(state: LBFGSResults):
i = his_size - 1 - j
_q, _a_his = carry
a_i = state.rho_history[i] * jnp.real(_dot(jnp.conj(state.s_history[i]), _q))
_a_his = ops.index_update(_a_his, ops.index[i], a_i)
_a_his = _a_his.at[i].set(a_i)
_q = _q - a_i * jnp.conj(state.y_history[i])
return _q, _a_his
@ -225,9 +225,9 @@ def _two_loop_recursion(state: LBFGSResults):
def _update_history_vectors(history, new):
# TODO(Jakob-Unfried) use rolling buffer instead? See #6053
return ops.index_update(jnp.roll(history, -1, axis=0), ops.index[-1, :], new)
return jnp.roll(history, -1, axis=0).at[-1, :].set(new)
def _update_history_scalars(history, new):
# TODO(Jakob-Unfried) use rolling buffer instead? See #6053
return ops.index_update(jnp.roll(history, -1, axis=0), ops.index[-1], new)
return jnp.roll(history, -1, axis=0).at[-1].set(new)

View File

@ -20,7 +20,6 @@ import scipy.special as osp_special
from jax._src import api
from jax import jit
from jax import lax, core
from jax import ops
from jax.interpreters import ad
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.lax_numpy import (asarray, _reduction_dims, _constant_like,
@ -765,14 +764,14 @@ def _gen_derivatives(p: jnp.ndarray,
p_p1 = p[1, 1:num_l - 1, :]
coeff = -1.0 / ((l_vec + 1) * l_vec)
update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1)
p_mm2_lm1 = p_mm2_lm1.at[ops.index[1, 2:num_l, :]].set(update_p_p1)
p_mm2_lm1 = p_mm2_lm1.at[1, 2:num_l, :].set(update_p_p1)
if num_l > 2:
l_vec = jnp.arange(2, num_l - 1)
p_p2 = p[2, 2:num_l - 1, :]
coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec)
update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
p_mm2_lm1 = p_mm2_lm1.at[ops.index[0, 3:num_l, :]].set(update_p_p2)
p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2)
m_mat, l_mat = jnp.mgrid[:num_m, :num_l]

View File

@ -279,17 +279,13 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
@parameterized.named_parameters(
jtu.cases_from_list(
dict(testcase_name=f"_{op.__name__}", op=op) for op in (
jax.ops.index_add,
jax.ops.index_max,
jax.ops.index_min,
jax.ops.index_mul,
jax.ops.index_update,
dict(testcase_name=f"_{op}", op=op) for op in (
"add", "max", "min", "multiply", "set"
)))
def test_scatter_static(self, op):
values = np.ones((5, 6), dtype=np.float32)
update = np.float32(6.)
f_jax = jax.jit(lambda v, u: op(v, jax.ops.index[::2, 3:], u))
f_jax = jax.jit(lambda v, u: getattr(v.at[::2, 3:], op)(u))
self.ConvertAndCompare(f_jax, values, update)
@parameterized.named_parameters(

View File

@ -34,10 +34,10 @@ returns an updated array, e.g.::
arr = np.zeros(5)
def loop_body(i, acc_arr):
arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
arr1 = acc_arr.at[i].set(acc_arr[i] + 2.)
return lax.cond(i % 2 == 0,
arr1,
lambda arr1: ops.index_update(arr1, i, arr1[i] + 1),
lambda arr1: arr1.at[i].set(arr1[i] + 1),
arr1,
lambda arr1: arr1)
arr = lax.fori_loop(0, arr.shape[0], loop_body, arr)
@ -52,9 +52,9 @@ special `loops.scope` object and use `for` loops over special
with loops.Scope() as s:
s.arr = np.zeros(5) # Create the mutable state of the loop as `scope` fields.
for i in s.range(s.arr.shape[0]):
s.arr = ops.index_update(s.arr, i, s.arr[i] + 2.)
s.arr = s.arr.at[i].set(s.arr[i] + 2.)
for _ in s.cond_range(i % 2 == 0): # Conditionals as loops with 0 or 1 iterations
s.arr = ops.index_update(s.arr, i, s.arr[i] + 1.)
s.arr = s.arr.at[i].set(s.arr[i] + 1.)
Loops constructed with `range` must have literal constant bounds. If you need
loops with dynamic bounds, you can use the more general `while_range` iterator.

View File

@ -28,7 +28,6 @@ from jax import random
from jax import jit, grad, jvp, vjp, make_jaxpr, jacfwd, jacrev, hessian
from jax import vmap
from jax._src.util import partial
import jax.ops
from jax.config import config
config.parse_flags_with_absl()
@ -936,7 +935,7 @@ class BatchingTest(jtu.JaxTestCase):
self.assertEqual((), empty_tuple)
def testIndexAddBatchedIndexesOnly(self):
f = lambda x, idx, y: jax.ops.index_add(x, jax.ops.index[idx], y)
f = lambda x, idx, y: jnp.asarray(x).at[idx].add(y)
result = vmap(f, (None, 0, None))(np.zeros((10,)), np.arange(10,), 1.)
self.assertAllClose(result, np.eye(10), check_dtypes=False)

View File

@ -948,7 +948,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
def f(x):
n = x.shape[0]
y = np.arange(n, dtype=x.dtype)
return jax.ops.index_update(x, np.diag_indices(n), y)
return jax.device_put(x).at[np.diag_indices(n)].set(y)
rng = jtu.rand_default(self.rng())
check_grads(f, (rng((5, 5), np.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
1.)

View File

@ -1711,10 +1711,8 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def jnp_pad_with(vector, pad_width, iaxis, kwargs):
pad_value = kwargs.get('padder', 10)
vector = jax.ops.index_update(
vector, jax.ops.index[:pad_width[0]], pad_value)
vector = jax.ops.index_update(
vector, jax.ops.index[-pad_width[1]:], pad_value)
vector = vector.at[:pad_width[0]].set(pad_value)
vector = vector.at[-pad_width[1]:].set(pad_value)
return vector
arr = np.arange(6).reshape(2, 3)

View File

@ -20,7 +20,7 @@ import numpy as np
import re
import jax
from jax import lax, ops
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax.experimental import loops
@ -87,7 +87,7 @@ class LoopsTest(jtu.JaxTestCase):
assert n == y.shape[0]
s.out = jnp.zeros(shape=[n], dtype=jnp.float32)
for i in s.range(n):
s.out = ops.index_add(s.out, i, x[i] + y[i])
s.out = s.out.at[i].add(x[i] + y[i])
return s.out
x = jnp.array([1., 2., 3.], dtype=jnp.float32)
@ -104,7 +104,7 @@ class LoopsTest(jtu.JaxTestCase):
for i in s.range(n):
for j in s.range(p):
for k in s.range(m):
s.out = ops.index_add(s.out, (i, j), x[i, k] * y[k, j])
s.out = s.out.at[(i, j)].add(x[i, k] * y[k, j])
return s.out
x = jnp.array([[1., 2., 3.]], dtype=jnp.float32) # 1x3
@ -178,10 +178,10 @@ class LoopsTest(jtu.JaxTestCase):
def f_op_jax():
arr = jnp.zeros(5)
def loop_body(i, acc_arr):
arr1 = ops.index_update(acc_arr, i, acc_arr[i] + 2.)
arr1 = acc_arr.at[i].set(acc_arr[i] + 2.)
return lax.cond(i % 2 == 0,
arr1,
lambda arr1: ops.index_update(arr1, i, arr1[i] + 1.),
lambda arr1: arr1.at[i].set(arr1[i] + 1.),
arr1,
lambda arr1: arr1)
arr = lax.fori_loop(0, arr.shape[0], loop_body, arr)
@ -191,9 +191,9 @@ class LoopsTest(jtu.JaxTestCase):
with loops.Scope() as s:
s.arr = jnp.zeros(5) # Must create the mutable state of the loop as `scope` fields.
for i in s.range(s.arr.shape[0]):
s.arr = ops.index_update(s.arr, i, s.arr[i] + 2.)
s.arr = s.arr.at[i].set(s.arr[i] + 2.)
for _ in s.cond_range(i % 2 == 0): # Conditionals are also sugared as loops with 0 or 1 iterations
s.arr = ops.index_update(s.arr, i, s.arr[i] + 1.)
s.arr = s.arr.at[i].set(s.arr[i] + 1.)
return s.arr
self.assertAllClose(f_expected(), f_op_jax())