mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #1394 from j-towns/fix-scatter-caching
Ensure all ops get cache hits on second op-by-op mode call
This commit is contained in:
commit
762b602f33
@ -785,6 +785,9 @@ def scatter_max(operand, scatter_indices, updates, dimension_numbers):
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
updates_shape=updates.shape)
|
||||
|
||||
# Define this outside of scatter to ensure cache hits.
|
||||
_scatter_reduction_computation = lambda x, y: y
|
||||
|
||||
def scatter(operand, scatter_indices, updates, dimension_numbers):
|
||||
"""Scatter-update operator.
|
||||
|
||||
@ -809,7 +812,8 @@ def scatter(operand, scatter_indices, updates, dimension_numbers):
|
||||
Returns:
|
||||
An array containing the sum of `operand` and the scattered updates.
|
||||
"""
|
||||
jaxpr, consts = _reduction_jaxpr(lambda x, y: y, _abstractify(_const(operand, 0)))
|
||||
jaxpr, consts = _reduction_jaxpr(_scatter_reduction_computation,
|
||||
_abstractify(_const(operand, 0)))
|
||||
return scatter_p.bind(
|
||||
operand, scatter_indices, updates, update_jaxpr=jaxpr,
|
||||
update_consts=consts, dimension_numbers=dimension_numbers,
|
||||
|
@ -79,6 +79,17 @@ class FixedPointError(Exception): pass
|
||||
|
||||
### fori_loop and while_loop
|
||||
|
||||
def _fori_cond_fun(loop_carry):
|
||||
i, upper, _ = loop_carry
|
||||
return lax.lt(i, upper)
|
||||
|
||||
@cache()
|
||||
def _fori_body_fun(body_fun):
|
||||
def while_body_fun(loop_carry):
|
||||
i, upper, x = loop_carry
|
||||
return lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)
|
||||
return while_body_fun
|
||||
|
||||
def fori_loop(lower, upper, body_fun, init_val):
|
||||
"""Loop from ``lower`` to ``upper`` by reduction to ``while_loop``.
|
||||
|
||||
@ -108,15 +119,8 @@ def fori_loop(lower, upper, body_fun, init_val):
|
||||
Returns:
|
||||
Loop value from the final iteration, of type ``a``.
|
||||
"""
|
||||
def while_cond_fun(loop_carry):
|
||||
i, _ = loop_carry
|
||||
return lax.lt(i, upper)
|
||||
|
||||
def while_body_fun(loop_carry):
|
||||
i, x = loop_carry
|
||||
return lax.add(i, lax._const(i, 1)), body_fun(i, x)
|
||||
|
||||
_, result = while_loop(while_cond_fun, while_body_fun, (lower, init_val))
|
||||
_, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
|
||||
(lower, upper, init_val))
|
||||
return result
|
||||
|
||||
|
||||
|
@ -593,6 +593,17 @@ xla.backend_specific_translations['gpu'][lu_p] = partial(
|
||||
_lu_cpu_gpu_translation_rule, cusolver.getrf)
|
||||
|
||||
|
||||
# Define this outside lu_pivots_to_permutation to ensure fori_loop cache hits
|
||||
def _lu_pivots_body_fn(i, permutation_and_swaps):
|
||||
permutation, swaps = permutation_and_swaps
|
||||
batch_dims = swaps.shape[:-1]
|
||||
j = swaps[..., i]
|
||||
iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims))
|
||||
x = permutation[..., i]
|
||||
y = permutation[iotas + (j,)]
|
||||
permutation = ops.index_update(permutation, ops.index[..., i], y)
|
||||
return ops.index_update(permutation, ops.index[iotas + (j,)], x), swaps
|
||||
|
||||
def lu_pivots_to_permutation(swaps, m):
|
||||
"""Converts the pivots (row swaps) returned by LU to a permutation.
|
||||
|
||||
@ -609,18 +620,11 @@ def lu_pivots_to_permutation(swaps, m):
|
||||
batch_dims = swaps.shape[:-1]
|
||||
k = swaps.shape[-1]
|
||||
|
||||
def body_fn(i, permutation):
|
||||
j = swaps[..., i]
|
||||
iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims))
|
||||
x = permutation[..., i]
|
||||
y = permutation[iotas + (j,)]
|
||||
permutation = ops.index_update(permutation, ops.index[..., i], y)
|
||||
return ops.index_update(permutation, ops.index[iotas + (j,)], x)
|
||||
|
||||
permutation = lax.broadcasted_iota(np.int32, batch_dims + (m,),
|
||||
len(batch_dims))
|
||||
return lax.fori_loop(
|
||||
onp.array(0, onp.int32), onp.array(k, onp.int32), body_fn, permutation)
|
||||
result, _ = lax.fori_loop(onp.array(0, onp.int32), onp.array(k, onp.int32),
|
||||
_lu_pivots_body_fn, (permutation, swaps))
|
||||
return result
|
||||
|
||||
|
||||
# QR decomposition
|
||||
|
@ -2833,23 +2833,24 @@ hanning = _wrap_numpy_nullary_function(onp.hanning)
|
||||
# TODO: lower `kaiser` via lax to allow non-constant beta values.
|
||||
kaiser = _wrap_numpy_nullary_function(onp.kaiser)
|
||||
|
||||
def _gcd_cond_fn(xs):
|
||||
x1, x2 = xs
|
||||
return any(x2 != 0)
|
||||
|
||||
def _gcd_body_fn(xs):
|
||||
x1, x2 = xs
|
||||
x1, x2 = (where(x2 != 0, x2, x1),
|
||||
where(x2 != 0, lax.rem(x1, x2), lax._const(x2, 0)))
|
||||
return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2))
|
||||
|
||||
@_wraps(getattr(onp, "gcd", None))
|
||||
def gcd(x1, x2):
|
||||
if (not issubdtype(_dtype(x1), integer) or
|
||||
not issubdtype(_dtype(x2), integer)):
|
||||
raise ValueError("Arguments to gcd must be integers.")
|
||||
def cond_fn(xs):
|
||||
x1, x2 = xs
|
||||
return any(x2 != 0)
|
||||
def body_fn(xs):
|
||||
x1, x2 = xs
|
||||
x1, x2 = (where(x2 != 0, x2, x1),
|
||||
where(x2 != 0, lax.rem(x1, x2), lax._const(x2, 0)))
|
||||
return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2))
|
||||
x1, x2 = _promote_dtypes(lax.abs(x1), lax.abs(x2))
|
||||
x1, x2 = broadcast_arrays(x1, x2)
|
||||
gcd, _ = lax.while_loop(cond_fn, body_fn, (x1, x2))
|
||||
gcd, _ = lax.while_loop(_gcd_cond_fn, _gcd_body_fn, (x1, x2))
|
||||
return gcd
|
||||
|
||||
|
||||
|
@ -35,6 +35,7 @@ from .config import flags
|
||||
from .util import partial
|
||||
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
|
||||
from .lib import xla_bridge
|
||||
from .interpreters import xla
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
@ -566,6 +567,13 @@ class JaxTestCase(parameterized.TestCase):
|
||||
python_should_be_executing = True
|
||||
python_ans = fun(*args)
|
||||
|
||||
cache_misses = xla.xla_primitive_callable.cache_info().misses
|
||||
python_ans = fun(*args)
|
||||
self.assertEqual(
|
||||
cache_misses, xla.xla_primitive_callable.cache_info().misses,
|
||||
"Compilation detected during second call of {} in op-by-op "
|
||||
"mode.".format(fun))
|
||||
|
||||
cfun = api.jit(wrapped_fun)
|
||||
python_should_be_executing = True
|
||||
monitored_ans = cfun(*args)
|
||||
|
@ -300,19 +300,17 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testForiLoopBasic(self):
|
||||
def body_fun(i, tot):
|
||||
return lax.add(tot, i)
|
||||
|
||||
def count(num):
|
||||
def body_fun(i, tot):
|
||||
return lax.add(tot, i)
|
||||
return lax.fori_loop(0, num, body_fun, 0)
|
||||
|
||||
cfun = api.jit(count)
|
||||
|
||||
self.assertEqual(count(2), 1)
|
||||
self.assertEqual(count(2), cfun(2))
|
||||
self.assertEqual(count(3), 3)
|
||||
self.assertEqual(count(3), cfun(3))
|
||||
self.assertEqual(count(4), 6)
|
||||
self.assertEqual(count(4), cfun(4))
|
||||
for args_maker in [lambda: [2], lambda: [3], lambda: [4]]:
|
||||
self._CompileAndCheck(count, args_maker, True)
|
||||
|
||||
def testForiLoopClosure(self):
|
||||
def count(num):
|
||||
|
Loading…
x
Reference in New Issue
Block a user