Merge pull request #1394 from j-towns/fix-scatter-caching

Ensure all ops get cache hits on second op-by-op mode call
This commit is contained in:
Matthew Johnson 2019-09-26 06:48:42 -07:00 committed by GitHub
commit 762b602f33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 55 additions and 36 deletions

View File

@ -785,6 +785,9 @@ def scatter_max(operand, scatter_indices, updates, dimension_numbers):
update_consts=consts, dimension_numbers=dimension_numbers,
updates_shape=updates.shape)
# Define this outside of scatter to ensure cache hits.
_scatter_reduction_computation = lambda x, y: y
def scatter(operand, scatter_indices, updates, dimension_numbers):
"""Scatter-update operator.
@ -809,7 +812,8 @@ def scatter(operand, scatter_indices, updates, dimension_numbers):
Returns:
An array containing the values of `operand` with the scattered updates applied.
"""
jaxpr, consts = _reduction_jaxpr(lambda x, y: y, _abstractify(_const(operand, 0)))
jaxpr, consts = _reduction_jaxpr(_scatter_reduction_computation,
_abstractify(_const(operand, 0)))
return scatter_p.bind(
operand, scatter_indices, updates, update_jaxpr=jaxpr,
update_consts=consts, dimension_numbers=dimension_numbers,

View File

@ -79,6 +79,17 @@ class FixedPointError(Exception): pass
### fori_loop and while_loop
def _fori_cond_fun(loop_carry):
i, upper, _ = loop_carry
return lax.lt(i, upper)
@cache()
def _fori_body_fun(body_fun):
  """Build a loop body over ``(i, upper, x)`` carries from a ``(i, x)`` body.

  Args:
    body_fun: callable taking ``(i, x)`` and returning the next ``x``.

  Returns:
    A function mapping ``(i, upper, x)`` to ``(i + 1, upper, body_fun(i, x))``.
  """
  # NOTE(review): decorated with ``cache()``, presumably so that the same
  # ``body_fun`` always maps to the same wrapper object - confirm against the
  # ``cache`` helper's definition.
  def while_body_fun(loop_carry):
    index, bound, value = loop_carry
    # The increment constant is built from ``index`` via ``lax._const``.
    next_index = lax.add(index, lax._const(index, 1))
    return next_index, bound, body_fun(index, value)
  return while_body_fun
def fori_loop(lower, upper, body_fun, init_val):
"""Loop from ``lower`` to ``upper`` by reduction to ``while_loop``.
@ -108,15 +119,8 @@ def fori_loop(lower, upper, body_fun, init_val):
Returns:
Loop value from the final iteration, of type ``a``.
"""
def while_cond_fun(loop_carry):
i, _ = loop_carry
return lax.lt(i, upper)
def while_body_fun(loop_carry):
i, x = loop_carry
return lax.add(i, lax._const(i, 1)), body_fun(i, x)
_, result = while_loop(while_cond_fun, while_body_fun, (lower, init_val))
_, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
(lower, upper, init_val))
return result

View File

@ -593,6 +593,17 @@ xla.backend_specific_translations['gpu'][lu_p] = partial(
_lu_cpu_gpu_translation_rule, cusolver.getrf)
# Kept at module level, outside lu_pivots_to_permutation, to ensure fori_loop
# cache hits.
def _lu_pivots_body_fn(i, permutation_and_swaps):
  """Exchange the permutation entries at position ``i`` and at the ``i``-th pivot.

  Args:
    i: loop index; selects ``swaps[..., i]`` and ``permutation[..., i]``.
    permutation_and_swaps: pair ``(permutation, swaps)``.

  Returns:
    A pair ``(updated_permutation, swaps)``; ``swaps`` is passed through
    unchanged so it can be carried by the loop.
  """
  perm, swaps = permutation_and_swaps
  pivot = swaps[..., i]
  # One iota per leading (batch) dimension of ``swaps``, combined with the
  # pivot to form the index of the entries to exchange.
  batch_iotas = np.ix_(*(lax.iota(np.int32, size) for size in swaps.shape[:-1]))
  pivot_index = batch_iotas + (pivot,)
  value_at_i = perm[..., i]
  value_at_pivot = perm[pivot_index]
  perm = ops.index_update(perm, ops.index[..., i], value_at_pivot)
  return ops.index_update(perm, ops.index[pivot_index], value_at_i), swaps
def lu_pivots_to_permutation(swaps, m):
"""Converts the pivots (row swaps) returned by LU to a permutation.
@ -609,18 +620,11 @@ def lu_pivots_to_permutation(swaps, m):
batch_dims = swaps.shape[:-1]
k = swaps.shape[-1]
def body_fn(i, permutation):
j = swaps[..., i]
iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims))
x = permutation[..., i]
y = permutation[iotas + (j,)]
permutation = ops.index_update(permutation, ops.index[..., i], y)
return ops.index_update(permutation, ops.index[iotas + (j,)], x)
permutation = lax.broadcasted_iota(np.int32, batch_dims + (m,),
len(batch_dims))
return lax.fori_loop(
onp.array(0, onp.int32), onp.array(k, onp.int32), body_fn, permutation)
result, _ = lax.fori_loop(onp.array(0, onp.int32), onp.array(k, onp.int32),
_lu_pivots_body_fn, (permutation, swaps))
return result
# QR decomposition

View File

@ -2833,23 +2833,24 @@ hanning = _wrap_numpy_nullary_function(onp.hanning)
# TODO: lower `kaiser` via lax to allow non-constant beta values.
kaiser = _wrap_numpy_nullary_function(onp.kaiser)
def _gcd_cond_fn(xs):
x1, x2 = xs
return any(x2 != 0)
def _gcd_body_fn(xs):
  """Perform one Euclidean step on the pair ``(x1, x2)``.

  Where ``x2`` is nonzero the pair becomes ``(x2, x1 rem x2)``; elsewhere it
  becomes ``(x1, 0)``. The result is then ordered so the first component is
  the larger of the two.

  Args:
    xs: pair ``(x1, x2)``.

  Returns:
    The updated pair.
  """
  a, b = xs
  b_nonzero = b != 0
  a, b = (where(b_nonzero, b, a),
          where(b_nonzero, lax.rem(a, b), lax._const(b, 0)))
  needs_swap = a < b
  return (where(needs_swap, b, a), where(needs_swap, a, b))
@_wraps(getattr(onp, "gcd", None))
def gcd(x1, x2):
  # Only integer inputs are accepted; anything else is rejected up front.
  if (not issubdtype(_dtype(x1), integer) or
      not issubdtype(_dtype(x2), integer)):
    raise ValueError("Arguments to gcd must be integers.")
  # Work on absolute values, promoted to a common dtype and broadcast to a
  # common shape before entering the loop.
  x1, x2 = _promote_dtypes(lax.abs(x1), lax.abs(x2))
  x1, x2 = broadcast_arrays(x1, x2)
  # Use the module-level loop functions rather than per-call closures: the
  # same function objects are passed on every call (which is what allows
  # cache hits), and the loop is run exactly once.
  result, _ = lax.while_loop(_gcd_cond_fn, _gcd_body_fn, (x1, x2))
  return result

View File

@ -35,6 +35,7 @@ from .config import flags
from .util import partial
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from .lib import xla_bridge
from .interpreters import xla
FLAGS = flags.FLAGS
@ -566,6 +567,13 @@ class JaxTestCase(parameterized.TestCase):
python_should_be_executing = True
python_ans = fun(*args)
cache_misses = xla.xla_primitive_callable.cache_info().misses
python_ans = fun(*args)
self.assertEqual(
cache_misses, xla.xla_primitive_callable.cache_info().misses,
"Compilation detected during second call of {} in op-by-op "
"mode.".format(fun))
cfun = api.jit(wrapped_fun)
python_should_be_executing = True
monitored_ans = cfun(*args)

View File

@ -300,19 +300,17 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testForiLoopBasic(self):
  """Check ``fori_loop`` sums ``0..num-1`` in op-by-op and jitted modes."""
  # Defined once, outside ``count``, so that every ``count`` call hands the
  # same function object to ``fori_loop``; a body re-created inside ``count``
  # would be a new object on each call.
  def body_fun(i, tot):
    return lax.add(tot, i)

  def count(num):
    return lax.fori_loop(0, num, body_fun, 0)

  cfun = api.jit(count)

  for num, expected in [(2, 1), (3, 3), (4, 6)]:
    self.assertEqual(count(num), expected)
    self.assertEqual(count(num), cfun(num))

  for args_maker in [lambda: [2], lambda: [3], lambda: [4]]:
    self._CompileAndCheck(count, args_maker, True)
def testForiLoopClosure(self):
def count(num):