mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[XLA:TPU] Support scan loops for parameter input and output streaming in host offloading.
Currently, parameter input and output streaming use HloAliasAnalysis to correlate MoveToDevice calls with their corresponding input buffers. However, this can break down in scan loops, in which a dynamic slice creates the buffer which is offloaded. This prevents the AliasAnalysis from operating on the right buffer and finding the entry parameter. This change adds a function to TryParameterStreaming which traces up the call graph to potentially find a dynamic slice, and if so performs alias analysis on the input to that dynamic slice. For TryOutputStreaming, we trace down the call graph to find a dynamic update slice and perform alias analysis on that buffer instead. PiperOrigin-RevId: 626894899
This commit is contained in:
parent
e498bca223
commit
fd1007806f
@ -1152,6 +1152,26 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out_host, np_inp)
|
||||
self.assertEqual(out_host.sharding, s_host)
|
||||
|
||||
def test_parameter_streaming_inside_scan(self):
  """Stream a pinned-host input to device memory inside a `lax.scan` body.

  The input array is placed in `pinned_host` memory, and each scanned slice
  is moved to `device` memory inside the loop body before being added to
  the carry. The test checks both the values and the sharding of the
  stacked scan output.
  """
  host_values = np.arange(4096.0).reshape(16, 16, 16)
  device_mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z"))
  host_sharding = NamedSharding(
      device_mesh, P("x", "y", "z"), memory_kind="pinned_host")
  host_array = jax.device_put(host_values, host_sharding)

  def step(carry, x):
    # Transfer the per-iteration slice to device memory before using it.
    on_device = jax.device_put(x, TransferToMemoryKind("device"))
    return carry, on_device + carry

  @jax.jit
  def f(xs):
    return jax.lax.scan(step, 1.0, xs)

  _, stacked_out = f(host_array)
  self.assertArraysEqual(stacked_out, host_values + 1.0)

  # Only expect the last dimension to have a named sharding.
  expected_sharding = NamedSharding(
      device_mesh, P(None, None, "z"), memory_kind="device")
  self.assertEqual(stacked_out.sharding, expected_sharding)
|
||||
|
||||
|
||||
class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user