[XLA:TPU] Support scan loops for parameter input and output streaming in host offloading.

Currently, parameter input and output streaming use HloAliasAnalysis to correlate MoveToDevice calls with their corresponding input buffers. However, this breaks down in scan loops, where the buffer being offloaded is created by a dynamic slice. Alias analysis then operates on the wrong buffer and cannot find the entry parameter.

This change adds a function to TryParameterStreaming that traces up the call graph looking for a dynamic slice and, if it finds one, performs alias analysis on the input to that dynamic slice. For TryOutputStreaming, we trace down the call graph to find a dynamic update slice and perform alias analysis on that buffer instead.
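
The tracing described above can be illustrated with a toy model. This is only a sketch under stated assumptions, not the XLA implementation: `Node`, `make`, `PASS_THROUGH`, `trace_up_to_sliced_buffer` and `trace_down_to_updated_buffer` are hypothetical names invented for illustration, opcodes are plain strings, and the graph is a single flat computation rather than a call graph spanning a while-loop body.

```python
from dataclasses import dataclass, field
from typing import List, Optional

# Hypothetical set of opcodes treated as pass-through while tracing.
# The real pass may follow a different set of instructions.
PASS_THROUGH = {"bitcast", "reshape", "get-tuple-element"}


@dataclass(eq=False)
class Node:
    """Toy stand-in for an HLO instruction (not an XLA type)."""
    name: str
    opcode: str
    operands: List["Node"] = field(default_factory=list)
    users: List["Node"] = field(default_factory=list)


def make(name: str, opcode: str, *operands: Node) -> Node:
    """Create a node and register it as a user of each operand."""
    node = Node(name, opcode, list(operands))
    for operand in operands:
        operand.users.append(node)
    return node


def trace_up_to_sliced_buffer(move_to_device: Node) -> Optional[Node]:
    """Follow operand edges; return the buffer a dynamic-slice reads, if any."""
    node = move_to_device.operands[0]
    while True:
        if node.opcode == "dynamic-slice":
            return node.operands[0]  # the full buffer, not the slice
        if node.opcode in PASS_THROUGH and node.operands:
            node = node.operands[0]
        else:
            return None  # no dynamic slice on this path


def trace_down_to_updated_buffer(move_to_host: Node) -> Optional[Node]:
    """Follow single-user edges; return the dynamic-update-slice, if any."""
    node = move_to_host
    while len(node.users) == 1:
        node = node.users[0]
        if node.opcode == "dynamic-update-slice":
            return node
        if node.opcode not in PASS_THROUGH:
            return None
    return None


# Input side: xs -> dynamic-slice -> reshape -> move-to-device.
xs = make("xs", "parameter")
idx = make("i", "parameter")
x = make("x", "dynamic-slice", xs, idx)
x_r = make("x_r", "reshape", x)
to_dev = make("to_dev", "move-to-device", x_r)

# Output side: y -> move-to-host -> reshape -> dynamic-update-slice into ys.
ys = make("ys", "parameter")
y = make("y", "parameter")
to_host = make("to_host", "move-to-host", y)
y_r = make("y_r", "reshape", to_host)
ys_next = make("ys_next", "dynamic-update-slice", ys, y_r, idx)

# No dynamic slice at all: the trace finds nothing and the caller would
# fall back to analysing the MoveToDevice operand directly.
p = make("p", "parameter")
plain = make("plain", "move-to-device", p)
```

In this toy graph, `trace_up_to_sliced_buffer(to_dev)` returns `xs`, the buffer that alias analysis would then be run on, and `trace_down_to_updated_buffer(to_host)` returns `ys_next`. For `plain`, the upward trace returns `None`.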

PiperOrigin-RevId: 626894899
Jackson Stokes 2024-04-21 20:02:24 -07:00 committed by jax authors
parent e498bca223
commit fd1007806f

@@ -1152,6 +1152,26 @@ class DevicePutTest(jtu.JaxTestCase):
    self.assertArraysEqual(out_host, np_inp)
    self.assertEqual(out_host.sharding, s_host)

  def test_parameter_streaming_inside_scan(self):
    mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z"))
    np_inp = np.arange(4096.0).reshape(16, 16, 16)
    s_host = NamedSharding(mesh, P("x", "y", "z"), memory_kind="pinned_host")
    arr_host = jax.device_put(np_inp, s_host)

    @jax.jit
    def f(xs):
      def body(carry, x):
        x_tpu = jax.device_put(x, TransferToMemoryKind("device"))
        return carry, x_tpu + carry
      return jax.lax.scan(body, 1.0, xs)

    _, out_hbm = f(arr_host)
    self.assertArraysEqual(out_hbm, np_inp + 1.0)
    # Only expect the last dimension to have a named sharding.
    out_s = NamedSharding(mesh, P(None, None, "z"), memory_kind="device")
    self.assertEqual(out_hbm.sharding, out_s)


class ActivationOffloadingTest(jtu.JaxTestCase):