rocm_jax/tests/memories_test.py
2025-03-18 13:01:48 -07:00

1896 lines
63 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import math
import re
from absl.testing import absltest
from absl.testing import parameterized
from absl import flags
import unittest
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.layout import DeviceLocalLayout as DLL, Layout
from jax._src import config
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.numpy as jnp
from jax.ad_checkpoint import Offloadable, remat, Recompute
from jax._src.sharding import common_devices_indices_map
from jax._src.sharding_impls import (NamedSharding, PositionalSharding,
SingleDeviceSharding, GSPMDSharding,
TransferToMemoryKind, PartitionSpec as P)
from jax.experimental.compute_on import compute_on
from jax.experimental.shard_map import shard_map
import numpy as np
# Keep a module-level handle on absl's flag registry, then let JAX's config
# consume the absl-parsed command-line flags.
FLAGS = flags.FLAGS
config.parse_flags_with_absl()
def get_memory_kinds_from_executable(f, args):
  """Returns the first entry of the compiled executable's output memory kinds.

  Args:
    f: a jitted callable exposing `.lower(*args)`.
    args: positional arguments used to lower `f`.
  """
  executable = f.lower(*args).compile().runtime_executable()
  output_kinds = executable.get_output_memory_kinds()
  return output_kinds[0]
def _create_inputs(shape, pspec, mem_kind=None):
  """Places an iota array of `shape` on a 2x2 ("x", "y") mesh.

  Args:
    shape: shape of the generated input.
    pspec: PartitionSpec used for the NamedSharding.
    mem_kind: optional memory kind for the sharding (None keeps the default).

  Returns:
    (mesh, sharding, numpy input, device_put array).
  """
  mesh = jtu.create_mesh((2, 2), ("x", "y"))
  sharding = NamedSharding(mesh, pspec, memory_kind=mem_kind)
  host_values = np.arange(math.prod(shape)).reshape(shape)
  placed = jax.device_put(host_values, sharding)
  return mesh, sharding, host_values, placed
class ShardingMemoriesTest(jtu.JaxTestCase):
  """Checks memory-kind handling on the sharding classes themselves."""

  # (test-name suffix, value of `name`) for every sharding flavour exercised
  # by the parameterized tests below.
  _SHARDING_NAMES = (
      ("named_sharding", "named_sharding"),
      ("positional_sharding", "positional_sharding"),
      ("single_device_sharding", "single_device_sharding"),
      ("gspmd_sharding", "gspmd_sharding"),
  )

  def setUp(self):
    super().setUp()
    # Expected default memory kind: "unpinned_host" on CPU, "device" otherwise
    # (test_default_memory_kind checks this against the device itself).
    if jtu.test_device_matches(["cpu"]):
      self._default_memory_kind = "unpinned_host"
    else:
      self._default_memory_kind = "device"

  @parameterized.named_parameters(*_SHARDING_NAMES)
  def test_canonicalize_memory_kind(self, name):
    """A sharding built without memory_kind reports the backend default."""
    if name == "named_sharding":
      mesh = jtu.create_mesh((1,), ("x",))
      ns = NamedSharding(mesh, P("x"))
      self.assertEqual(ns.memory_kind, self._default_memory_kind)
    elif name == "positional_sharding":
      ps = PositionalSharding(jax.devices())
      self.assertEqual(ps.memory_kind, self._default_memory_kind)
    elif name == "single_device_sharding":
      ss = SingleDeviceSharding(jax.devices()[0])
      self.assertEqual(ss.memory_kind, self._default_memory_kind)
    else:
      assert name == "gspmd_sharding"
      gs = GSPMDSharding.get_replicated(jax.devices())
      self.assertEqual(gs.memory_kind, self._default_memory_kind)

  @parameterized.named_parameters(*_SHARDING_NAMES)
  def test_wrong_memory_kind(self, name):
    """Unknown memory kinds are rejected with a ValueError."""
    if name == "named_sharding":
      # Build the mesh outside the assertion so only the sharding constructor
      # is expected to raise.
      mesh = jtu.create_mesh((1,), ("x",))
      with self.assertRaisesRegex(
          ValueError, "Could not find memory addressable by device.*"
      ):
        NamedSharding(mesh, P("x"), memory_kind="hbm")
    elif name == "positional_sharding":
      with self.assertRaisesRegex(
          ValueError, "Could not find memory addressable by device.*"
      ):
        PositionalSharding(jax.devices(), memory_kind="gpu_hbm")
    elif name == "single_device_sharding":
      with self.assertRaisesRegex(
          ValueError,
          "Could not find memory addressable by device.*Device.*"
          " can address the following memory kinds.*",
      ):
        SingleDeviceSharding(jax.devices()[0], memory_kind="host")
    else:
      assert name == "gspmd_sharding"
      with self.assertRaisesRegex(
          ValueError, "Could not find memory addressable by device.*"
      ):
        GSPMDSharding.get_replicated(jax.devices(), memory_kind="my_host")

  @parameterized.named_parameters(*_SHARDING_NAMES)
  def test_correct_tpu_memory_kind(self, name):
    """Valid TPU memory kinds are accepted by every sharding constructor."""
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("TPU memory kind test.")
    if name == "named_sharding":
      mesh = jtu.create_mesh((1,), ("x",))
      NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
    elif name == "positional_sharding":
      PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
    elif name == "single_device_sharding":
      SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host")
    else:
      assert name == "gspmd_sharding"
      GSPMDSharding.get_replicated(jax.devices(), memory_kind="unpinned_host")

  @parameterized.named_parameters(*_SHARDING_NAMES)
  def test_sharding_eq(self, name):
    """An implicit default memory kind equals the same kind passed explicitly."""
    if name == "named_sharding":
      mesh = jtu.create_mesh((1,), ("x",))
      s1 = NamedSharding(mesh, P("x"))
      s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
      self.assertEqual(s1, s2)
    elif name == "positional_sharding":
      s1 = PositionalSharding(jax.devices())
      s2 = PositionalSharding(
          jax.devices(), memory_kind=self._default_memory_kind)
      self.assertEqual(s1, s2)
    elif name == "single_device_sharding":
      s1 = SingleDeviceSharding(jax.devices()[0])
      s2 = SingleDeviceSharding(
          jax.devices()[0], memory_kind=self._default_memory_kind)
      self.assertEqual(s1, s2)
    elif name == "gspmd_sharding":
      s1 = GSPMDSharding.get_replicated(jax.devices())
      s2 = GSPMDSharding.get_replicated(
          jax.devices(), memory_kind=self._default_memory_kind)
      self.assertEqual(s1, s2)

  def test_sharding_equivalent(self):
    """NamedSharding and GSPMDSharding with matching HLO are equivalent,
    whichever side spells out the default memory kind."""
    mesh = jtu.create_mesh((1,), ("x",))
    ndim = 2
    ns1 = NamedSharding(mesh, P("x"))
    gs1 = GSPMDSharding(
        tuple(mesh.devices.flat),
        ns1._to_xla_hlo_sharding(ndim),
        memory_kind=self._default_memory_kind,
    )
    self.assertTrue(ns1.is_equivalent_to(gs1, ndim))

    ns2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
    gs2 = GSPMDSharding(
        tuple(mesh.devices.flat), ns2._to_xla_hlo_sharding(ndim)
    )
    self.assertTrue(ns2.is_equivalent_to(gs2, ndim))

  def test_default_memory_kind(self):
    """The device's default memory matches the expected default kind."""
    dev = jax.devices()[0]
    self.assertEqual(dev.default_memory().kind, self._default_memory_kind)
class DevicePutTest(jtu.JaxTestCase):
  """Transfers between device and host memory kinds via device_put and jit."""

  def setUp(self):
    if not jtu.test_device_matches(["tpu", "gpu"]):
      self.skipTest("Memories do not work on CPU backend yet.")
    super().setUp()

  def _skip_if_unpinned_host_on_gpu(self, host_memory_kind):
    """Skips the current test for the GPU + "unpinned_host" combination."""
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")

  def _skip_if_xla_version_below(self, min_version):
    """Skips the current test if the backend's XLA version is < `min_version`.

    A backend that reports no version (``None``) is not skipped.
    """
    xla_version = xb.backend_xla_version()
    if xla_version is not None and xla_version < min_version:
      self.skipTest(f"This test requires an xla_version >= {min_version}.")

  def _check_device_put_addressable_shards(
      self, out, inp, expected_sharding, expected_mem_kind, index=True):
    """Checks `out`'s values, sharding and memory kind, shard by shard.

    Args:
      out: the array under test.
      inp: reference values, compared against `out` as a whole.
      expected_sharding: the sharding `out` must have.
      expected_mem_kind: memory kind of `out` and of every addressable shard.
      index: if True, each shard is compared against `inp[shard.index]`;
        otherwise against all of `inp` (needed e.g. for Python scalars,
        which cannot be indexed).
    """
    self.assertArraysEqual(out, inp)
    self.assertEqual(out.sharding, expected_sharding)
    self.assertEqual(out.sharding.memory_kind, expected_mem_kind)
    for s in out.addressable_shards:
      if index:
        self.assertArraysEqual(s.data, inp[s.index])
      else:
        self.assertArraysEqual(s.data, inp)
      self.assertEqual(s.data.sharding.memory_kind, expected_mem_kind)

  def test_error_transfer_to_memory_kind_outside_jit(self):
    with self.assertRaisesRegex(
        ValueError,
        "TransferToMemoryKind argument to jax.device_put can only be used"
        " inside jax.jit"):
      jax.device_put(np.arange(16), TransferToMemoryKind("device"))

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_host_to_hbm(self, host_memory_kind: str):
    self._skip_if_unpinned_host_on_gpu(host_memory_kind)
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
    np_inp = np.arange(16).reshape(8, 2)

    out_on_host = jax.device_put(np_inp, s_host)
    self.assertEqual(out_on_host.sharding, s_host)

    s_hbm = s_host.with_memory_kind("device")
    out_on_hbm = jax.device_put(out_on_host, s_hbm)
    self._check_device_put_addressable_shards(
        out_on_hbm, np_inp, s_hbm, "device")

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_hbm_to_host(self, host_memory_kind: str):
    self._skip_if_unpinned_host_on_gpu(host_memory_kind)
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
    inp = jnp.arange(16).reshape(8, 2)

    out_on_host = jax.device_put(inp, s_host)
    self._check_device_put_addressable_shards(
        out_on_host, inp, s_host, host_memory_kind)

    sharded_inp = jax.device_put(inp, s_host.with_memory_kind("device"))
    sharded_out_on_host = jax.device_put(sharded_inp, s_host)
    self._check_device_put_addressable_shards(
        sharded_out_on_host, sharded_inp, s_host, host_memory_kind)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_different_device_and_memory_host_to_hbm(
      self, host_memory_kind: str
  ):
    self._skip_if_unpinned_host_on_gpu(host_memory_kind)
    if jax.device_count() < 3:
      self.skipTest("Test requires >=3 devices")
    out_host0 = jax.device_put(
        jnp.arange(8),
        SingleDeviceSharding(jax.devices()[0], memory_kind=host_memory_kind))

    dev2 = jax.devices()[2]
    out_hbm1 = jax.device_put(
        out_host0, SingleDeviceSharding(dev2, memory_kind="device"))
    self.assertEqual(out_hbm1.sharding.memory_kind, "device")
    self.assertEqual(out_hbm1.sharding._device, dev2)
    self.assertEqual(out_hbm1.addressable_shards[0].data.sharding._device, dev2)
    self.assertEqual(
        out_hbm1.addressable_shards[0].data.sharding.memory_kind, "device")

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_different_device_and_memory_hbm_to_host(
      self, host_memory_kind: str
  ):
    self._skip_if_unpinned_host_on_gpu(host_memory_kind)
    if jax.device_count() < 3:
      self.skipTest("Test requires >=3 devices")
    out_hbm0 = jnp.arange(8)

    dev2 = jax.devices()[2]
    out_host1 = jax.device_put(
        out_hbm0, SingleDeviceSharding(dev2, memory_kind=host_memory_kind))
    self.assertEqual(out_host1.sharding.memory_kind, host_memory_kind)
    self.assertEqual(out_host1.sharding._device, dev2)
    self.assertEqual(out_host1.addressable_shards[0].data.sharding._device,
                     dev2)
    self.assertEqual(
        out_host1.addressable_shards[0].data.sharding.memory_kind,
        host_memory_kind)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_on_different_device_with_the_same_memory_kind(
      self, host_memory_kind: str):
    self._skip_if_unpinned_host_on_gpu(host_memory_kind)
    if jax.device_count() < 2:
      self.skipTest("Test requires >=2 devices.")

    np_inp = np.arange(16).reshape(8, 2)

    # device memory on device 0 -> device memory on device 1.
    s_hbm_dev_0 = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
    s_hbm_dev_1 = SingleDeviceSharding(jax.devices()[1], memory_kind="device")
    inp_hbm_dev0 = jax.device_put(np_inp, s_hbm_dev_0)
    out_hbm_dev_1 = jax.device_put(inp_hbm_dev0, s_hbm_dev_1)
    self._check_device_put_addressable_shards(
        out_hbm_dev_1, np_inp, s_hbm_dev_1, "device")

    # host memory on device 0 -> host memory on device 1.
    inp_host_dev0 = jax.device_put(
        np_inp, s_hbm_dev_0.with_memory_kind(host_memory_kind))
    s_host_dev_1 = s_hbm_dev_1.with_memory_kind(host_memory_kind)
    out_host_dev_1 = jax.device_put(inp_host_dev0, s_host_dev_1)
    self._check_device_put_addressable_shards(
        out_host_dev_1, np_inp, s_host_dev_1, host_memory_kind)

  # TODO(yashkatariya): Enable this once we can compute on host.
  # def test_device_put_resharding(self):
  #   mesh = jtu.create_mesh((2, 2), ("x", "y"))
  #   s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")
  #   s_hbm = s_host.with_memory_kind("device")
  #   np_inp = np.arange(16).reshape(8, 2)
  #   # Reshard single device array on HBM to multi device on host
  #   sds_inp_hbm = jax.device_put(
  #       jnp.arange(16).reshape(8, 2),
  #       SingleDeviceSharding(jax.devices()[0], memory_kind="device"))
  #   # device_put on host
  #   out_sharded_host = jax.device_put(sds_inp_hbm, s_host)
  #   self._check_device_put_addressable_shards(
  #       out_sharded_host, np_inp, s_host, "unpinned_host")
  #   # Reshard single device array on host to multi device on hbm
  #   sds_inp_host = jax.device_put(
  #       jnp.arange(16).reshape(8, 2),
  #       SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host"))
  #   # device_put on hbm
  #   out_sharded_hbm = jax.device_put(sds_inp_host, s_hbm)
  #   self._check_device_put_addressable_shards(
  #       out_sharded_hbm, np_inp, s_hbm, "device")

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_numpy_array(self, host_memory_kind: str):
    self._skip_if_unpinned_host_on_gpu(host_memory_kind)
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    np_inp = np.arange(16).reshape(8, 2)
    s_hbm = NamedSharding(mesh, P(("x", "y")), memory_kind="device")
    s_host = s_hbm.with_memory_kind(host_memory_kind)

    out_hbm = jax.device_put(np_inp, s_hbm)
    self._check_device_put_addressable_shards(out_hbm, np_inp, s_hbm, "device")

    out_host = jax.device_put(np_inp, s_host)
    self._check_device_put_addressable_shards(
        out_host, np_inp, s_host, host_memory_kind)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_numpy_scalar(self, host_memory_kind: str):
    self._skip_if_unpinned_host_on_gpu(host_memory_kind)
    np_inp = np.float32(8)
    s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
    s_host = s_hbm.with_memory_kind(host_memory_kind)

    out_hbm = jax.device_put(np_inp, s_hbm)
    self._check_device_put_addressable_shards(out_hbm, np_inp, s_hbm, "device")

    out_host = jax.device_put(np_inp, s_host)
    self._check_device_put_addressable_shards(
        out_host, np_inp, s_host, host_memory_kind)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_python_scalar(self, host_memory_kind: str):
    self._skip_if_unpinned_host_on_gpu(host_memory_kind)
    py_scalar = float(8)
    s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
    s_host = s_hbm.with_memory_kind(host_memory_kind)

    out_hbm = jax.device_put(py_scalar, s_hbm)
    self._check_device_put_addressable_shards(
        out_hbm, py_scalar, s_hbm, "device", index=False)

    out_host = jax.device_put(py_scalar, s_host)
    self._check_device_put_addressable_shards(
        out_host, py_scalar, s_host, host_memory_kind, index=False)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_python_int(self, host_memory_kind: str):
    self._skip_if_unpinned_host_on_gpu(host_memory_kind)
    py_inp = 8
    s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
    s_host = s_hbm.with_memory_kind(host_memory_kind)

    out_hbm = jax.device_put(py_inp, s_hbm)
    self._check_device_put_addressable_shards(
        out_hbm, py_inp, s_hbm, "device", index=False)

    out_host = jax.device_put(py_inp, s_host)
    self._check_device_put_addressable_shards(
        out_host, py_inp, s_host, host_memory_kind, index=False)

  def test_device_put_inside_jit(self):
    _, s_host, np_inp, inp_host = _create_inputs(
        (8, 2), P("x", "y"), mem_kind="pinned_host")
    s_dev = s_host.with_memory_kind("device")

    @jax.jit
    def f(a, b):
      x, y = jax.device_put((a, b), s_dev)
      return x * y

    out = f(inp_host, inp_host)
    self._check_device_put_addressable_shards(
        out, np_inp * np_inp, s_dev, "device")

  def test_parameter_streaming(self):
    _, s_host, np_inp, inp_host = _create_inputs(
        (8, 2), P("x", "y"), mem_kind="pinned_host")
    s_dev = s_host.with_memory_kind('device')
    inp_dev = jax.device_put(np_inp, s_dev)

    @functools.partial(jax.jit, out_shardings=s_host)
    def f(a, b):
      x = b * 2
      y = jax.device_put(a, s_dev)
      z = x * y
      return z * 4, z

    compiled = f.lower(inp_host, inp_dev).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    # S(5) is presumably XLA's host memory-space annotation in the layout.
    self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out1, out2 = f(inp_host, inp_dev)
    self._check_device_put_addressable_shards(
        out1, np_inp * np_inp * 8, s_host, 'pinned_host')
    self._check_device_put_addressable_shards(
        out2, np_inp * np_inp * 2, s_host, 'pinned_host')

  def test_zero_size_parameter(self):
    if jtu.test_device_matches(["gpu"]):
      self.skipTest("This test does not work on GPU backend.")
    _, s_host, np_inp, inp_host = _create_inputs(
        (0,), P(), mem_kind="pinned_host")
    s_dev = s_host.with_memory_kind('device')

    @functools.partial(jax.jit, out_shardings=s_host)
    def f(a):
      b = jax.device_put(a, s_dev)
      return b

    compiled = f.lower(inp_host).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out = f(inp_host)
    self._check_device_put_addressable_shards(
        out, np_inp, s_host, 'pinned_host')

  def test_parameter_streaming_with_scalar_and_constant(self):
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    scalar_inp = 1
    s_host = NamedSharding(mesh, P(), memory_kind="pinned_host")

    @functools.partial(jax.jit, out_shardings=s_host)
    def f(scalar_input):
      y = jax.device_put(scalar_input, s_host)
      z = 2
      w = jax.device_put(z, s_host)
      return y, w

    compiled = f.lower(scalar_inp).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out1, out2 = f(scalar_inp)
    self._check_device_put_addressable_shards(
        out1, scalar_inp, s_host, "pinned_host", index=False
    )
    self._check_device_put_addressable_shards(
        out2, 2, s_host, "pinned_host", index=False
    )

  def test_parameter_and_output_streaming_with_array(self):
    self._skip_if_xla_version_below(2)
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    np_inp = np.arange(16).reshape(8, 2)
    s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host")
    inp_host = jax.device_put(np_inp, s_host)

    @functools.partial(jax.jit, out_shardings=(s_host, s_host))
    def f(x):
      return (x, x)

    compiled = f.lower(inp_host).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    if compiled_text is not None:
      self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out1, out2 = f(inp_host)
    self._check_device_put_addressable_shards(
        out1, np_inp, s_host, "pinned_host"
    )
    self._check_device_put_addressable_shards(
        out2, np_inp, s_host, "pinned_host"
    )

  def test_parameter_and_output_streaming_with_scalar(self):
    self._skip_if_xla_version_below(2)
    mesh = jax.sharding.Mesh(jax.devices(), "axis")
    s_host = jax.sharding.NamedSharding(
        mesh, jax.sharding.PartitionSpec(), memory_kind="pinned_host"
    )
    scalar_inp = 1

    @functools.partial(jax.jit, out_shardings=(s_host, s_host))
    def f(x):
      return (x, x)

    compiled = f.lower(scalar_inp).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    if compiled_text is not None:
      self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out1, out2 = f(scalar_inp)
    self._check_device_put_addressable_shards(
        out1, scalar_inp, s_host, "pinned_host", index=False
    )
    self._check_device_put_addressable_shards(
        out2, scalar_inp, s_host, "pinned_host", index=False
    )

  def test_identity_jit_host_to_device_and_vice_versa(self):
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    np_inp = np.arange(16).reshape(8, 2)
    s_host = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
    s_dev = s_host.with_memory_kind('device')
    arr_host = jax.device_put(np_inp, s_host)
    arr_dev = jax.device_put(np_inp, s_dev)

    # pinned_host -> device
    f = jax.jit(lambda x: x, out_shardings=s_dev)
    out_dev = f(arr_host)
    self.assertArraysEqual(out_dev, np_inp)
    self.assertEqual(out_dev.sharding, s_dev)

    # device -> pinned_host
    g = jax.jit(lambda x: x, out_shardings=s_host)
    out_host = g(arr_dev)
    self.assertArraysEqual(out_host, np_inp)
    self.assertEqual(out_host.sharding, s_host)

  def test_parameter_streaming_inside_scan(self):
    mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z"))
    np_inp = np.arange(4096.0).reshape(16, 16, 16)
    s_host = NamedSharding(mesh, P("x", "y", "z"), memory_kind="pinned_host")
    arr_host = jax.device_put(np_inp, s_host)

    @jax.jit
    def f(xs):
      def body(carry, x):
        x_tpu = jax.device_put(x, TransferToMemoryKind("device"))
        return carry, x_tpu + carry

      return jax.lax.scan(body, 1.0, xs)

    _, out_hbm = f(arr_host)
    self.assertArraysEqual(out_hbm, np_inp + 1.0)
    # Only expect the last dimension to have a named sharding.
    out_s = NamedSharding(mesh, P(None, None, "z"), memory_kind="device")
    self.assertEqual(out_hbm.sharding, out_s)

  def test_output_streaming(self):
    mesh = jtu.create_mesh((1, 1), ("x", "y"))
    np_inp = np.arange(16.0).reshape(8, 2)
    s_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="device")
    s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host")
    arr_hbm = jax.device_put(np_inp, s_hbm)

    @functools.partial(jax.jit, out_shardings=s_host)
    def f(xs):
      out_tpu = xs + 1.0
      return out_tpu

    out_host = f(arr_hbm)
    self.assertArraysEqual(out_host, np_inp + 1.0)
    self.assertEqual(out_host.sharding, s_host)

  def test_weight_offload_with_dp_on_output(self):
    _, s_dev, np_inp, inp_dev = _create_inputs(
        (8, 2), P("x", "y"), mem_kind="device")
    s_host = s_dev.with_memory_kind('pinned_host')

    @jax.jit
    def f(x):
      x = x * 2
      y = jax.device_put(x, s_host)
      return y

    out_host = f(inp_dev)
    self._check_device_put_addressable_shards(
        out_host, np_inp * 2, s_host, 'pinned_host')

  def test_output_streaming_inside_scan(self):
    self._skip_if_xla_version_below(2)
    mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z"))
    np_inp = np.arange(4096).reshape(16, 16, 16)
    s_hbm = NamedSharding(mesh, P(None, "y", "z"), memory_kind="device")
    arr_hbm = jax.device_put(np_inp, s_hbm)

    @jax.jit
    def f(xs):
      def body(carry, x):
        out_tpu = x + carry
        return carry, jax.device_put(
            out_tpu, NamedSharding(mesh, P("y", "z"), memory_kind="pinned_host"))

      _, res = jax.lax.scan(body, 1, xs)
      return res

    out = f(arr_hbm)
    self.assertArraysEqual(out, np_inp + 1)
    self.assertEqual(out.sharding.memory_kind, 'pinned_host')

  def test_deepcopy(self):
    self._skip_if_xla_version_below(2)
    mesh = jax.sharding.Mesh(jax.devices(), "x")
    s_host = NamedSharding(mesh, P(), memory_kind="pinned_host")
    t = jax.device_put(jnp.zeros((8, 2)), s_host)
    t_copy = copy.deepcopy(t)
    self.assertArraysEqual(t, t_copy)
    self.assertEqual(t.shape, t_copy.shape)

  def test_close_over_host_constant_and_stream(self):
    _, s_host, np_inp, inp_host = _create_inputs(
        (8, 2), P("x", "y"), mem_kind="pinned_host")
    s_dev = s_host.with_memory_kind('device')

    @functools.partial(jax.jit, out_shardings=s_dev)
    def f():
      y = jax.device_put(inp_host, s_dev)
      z = y * 2
      return z

    out = f()
    self._check_device_put_addressable_shards(out, np_inp * 2, s_dev, 'device')

  @jtu.run_on_devices('tpu')
  def test_ragged_copy_on_host(self):
    # `('x')` is just the string 'x'; spell single axis names as tuples.
    mesh = jtu.create_mesh((2,), ('x',))
    sharding = jax.sharding.NamedSharding(mesh, P('x'))
    cpu_sharding = sharding.with_memory_kind('pinned_host')

    num_pages = 512 * 1024
    page_size = 1024

    x = jnp.full((num_pages, page_size), 1, dtype=jnp.bfloat16, device=sharding)

    def write(x):
      # Within each shard, zero every row from 16 * 1024 onwards so the copy
      # loop below stops early.
      return x.at[16 * 1024:].set(0)

    x = shard_map(write, mesh, P('x'), P('x'))(x)

    chunk_size = 8

    def inner(state):
      # Copy the next `chunk_size` rows into the pinned-host output buffer.
      idx, x, output = state
      chunk = jax.lax.dynamic_slice_in_dim(x, idx * chunk_size, chunk_size)
      chunk_host = jax.device_put(chunk, TransferToMemoryKind('pinned_host'))
      output = jax.lax.dynamic_update_slice_in_dim(
          output, chunk_host, idx * chunk_size, axis=0)
      return (idx + 1, x, output)

    def cond(state):
      # Continue while rows remain and the next chunk has a positive entry.
      idx, x, _ = state
      chunk = jax.lax.dynamic_slice_in_dim(x, idx * chunk_size, chunk_size)
      return (idx * chunk_size < x.shape[0]) & jnp.any(chunk > 0)

    def foo(x):
      output = jnp.zeros_like(x, device=cpu_sharding)
      _, _, cpu_x = jax.lax.while_loop(cond, inner, (0, x, output))
      return cpu_x

    fn = jax.jit(shard_map(foo, mesh, P('x'), P('x'), check_rep=False),
                 out_shardings=cpu_sharding)
    y = fn(x)
    jax.block_until_ready(y)
    compiled_text = fn.lower(x).compile().as_text()
    if compiled_text is not None:
      self.assertIn('custom_call_target="AllocateBuffer"', compiled_text)

  def test_disallow_alias_copies_arrays(self):
    mesh = jtu.create_mesh((2,), ("x",))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P("x"), memory_kind="pinned_host")
    inp_host = jax.device_put(np_inp, s)

    inp_host_copy = jax.device_put(inp_host, may_alias=False)

    # Deleting the source must not invalidate the non-aliased copy.
    for a in jax.tree.leaves(inp_host):
      a.delete()

    jax.block_until_ready(inp_host_copy)

  def test_disallow_alias_copies_arrays_with_donated_input(self):
    mesh = jtu.create_mesh((2,), ("x",))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P("x"), memory_kind="pinned_host")
    inp_host = jax.device_put(np_inp, s)

    inp_host_donate = jax.jit(lambda x: x, donate_argnums=0)(inp_host)

    inp_host_donate_copy = jax.device_put(inp_host_donate, may_alias=False)

    # Deleting the source must not invalidate the non-aliased copy.
    for a in jax.tree.leaves(inp_host_donate):
      a.delete()

    jax.block_until_ready(inp_host_donate_copy)
class ComputeOffload(jtu.BufferDonationTestCase):
def setUp(self):
  """Runs only on TPU; other backends skip before base-class setup."""
  if jtu.test_device_matches(["tpu"]):
    super().setUp()
    return
  self.skipTest("Memories do not work on CPU and GPU backends yet.")
def _check_mem_kind(self, executable_kind, out_sharding, expected_kind):
  """Asserts the executable's kind, the sharding's kind and the expectation agree.

  Args:
    executable_kind: memory kind attributed to the executable's output.
    out_sharding: output sharding whose `memory_kind` is compared.
    expected_kind: the memory kind the caller expects.
  """
  sharding_kind = out_sharding.memory_kind
  for got, want in ((executable_kind, sharding_kind),
                    (sharding_kind, expected_kind),
                    (executable_kind, expected_kind)):
    self.assertEqual(got, want)
def test_compute_no_inputs(self):
  """A no-argument jit can produce one device and one pinned-host output."""
  # `('data')` would just be the string 'data'; pass a real tuple of axis
  # names like the rest of the file.
  mesh = jtu.create_mesh((4,), ('data',))

  tpu_sharding = NamedSharding(mesh, P('data'))
  cpu_sharding = NamedSharding(mesh, P('data'), memory_kind='pinned_host')

  @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding))
  def init():
    tpu_array = jax.random.normal(jax.random.key(42), (16, 16))
    cpu_array = jax.random.normal(jax.random.key(42), (16, 16))
    return tpu_array, cpu_array

  tpu_array, cpu_array = init()
  self.assertEqual(tpu_array.sharding, tpu_sharding)
  self.assertEqual(cpu_array.sharding, cpu_sharding)
def test_compute_no_inputs_host_replicated(self):
  """Like test_compute_no_inputs, but the host output is replicated (P())."""
  if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3:
    self.skipTest("This test requires an xla_version >= 3.")
  if config.use_shardy_partitioner.value:
    self.skipTest("XLA failure due to b/370786664 and b/366411266. "
                  "Enable when fixed.")
  # `('data')` would just be the string 'data'; pass a real tuple of axis
  # names like the rest of the file.
  mesh = jtu.create_mesh((4,), ('data',))

  tpu_sharding = NamedSharding(mesh, P('data'))
  cpu_sharding = NamedSharding(mesh, P(), memory_kind='pinned_host')

  @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding))
  def init():
    tpu_array = jax.random.normal(jax.random.key(42), (16, 16))
    cpu_array = jax.random.normal(jax.random.key(42), (16, 16))
    return tpu_array, cpu_array

  tpu_array, cpu_array = init()
  self.assertEqual(tpu_array.sharding, tpu_sharding)
  self.assertEqual(cpu_array.sharding, cpu_sharding)
def test_compute_on_basic(self):
  """A host-computed jit nested in a device jit, with device and host outputs."""
  @compute_on('device_host')
  @jax.jit
  def g(x):
    return x * 2

  @jax.jit
  def f(x):
    return g(x) * 3

  host_out_sharding = SingleDeviceSharding(
      jax.devices()[0], memory_kind='pinned_host')

  @functools.partial(jax.jit, out_shardings=host_out_sharding)
  def h(x):
    return g(x) * 3

  inp = jnp.arange(8)
  self.assertArraysEqual(f(inp), inp * 6)
  # The lowered module must mention the compute-type attribute of the call.
  self.assertIn('_xla_compute_type', f.lower(inp).as_text())

  host_out = h(inp)
  self.assertArraysEqual(host_out, inp * 6)
  self.assertEqual(host_out.sharding.memory_kind, 'pinned_host')
def test_compute_on_host_shared_sharding(self):
  """Host and device buffers sharing one mesh round-trip through a host jit."""
  # `("x")` would just be the string "x"; use a real tuple of axis names.
  mesh = jtu.create_mesh((2,), ("x",))
  device_sharding = NamedSharding(mesh, P("x"))
  host_sharding = device_sharding.with_memory_kind("pinned_host")

  @compute_on("device_host")
  @functools.partial(
      jax.jit,
      in_shardings=(host_sharding, device_sharding),
      out_shardings=(host_sharding, device_sharding),
      donate_argnums=(0, 1),
  )
  def host_func(x, y):
    return (x * y), ((x**2) * (y**2))

  @functools.partial(
      jax.jit,
      in_shardings=(host_sharding, device_sharding),
      out_shardings=(host_sharding, device_sharding),
      # `(0)` is the int 0, not a tuple; only the host argument is donated.
      donate_argnums=(0,),
  )
  def device_func(host_data, device_data):
    host_data, device_data = host_func(host_data, device_data)
    device_data = device_data * 2
    host_data, device_data = host_func(host_data, device_data)
    return (host_data, device_data)

  input_x = jnp.ones(8)
  input_host = jax.device_put(input_x, host_sharding)

  # [0, 0, 0, 0, 1, 1, 1, 1] on device.
  input_device = jnp.arange(8)
  input_device = jnp.where(input_device < 4, 0, 1)
  input_device = jax.device_put(input_device, device_sharding)

  output_host, output_device = device_func(input_host, input_device)
  self.assertEqual(output_host.sharding.memory_kind, 'pinned_host')
  self.assertEqual(output_device.sharding.memory_kind, 'device')
  self.assertArraysEqual(output_host, [0., 0., 0., 0., 2., 2., 2., 2.])
  self.assertArraysEqual(output_device, [0., 0., 0., 0., 4., 4., 4., 4.])
def test_compute_on_basic_inline(self):
  """compute_on is preserved when called through an inlined jit."""
  # The name `g` is matched by the HLO regex below; do not rename it.
  @compute_on('device_host')
  @jax.jit
  def g(x):
    return x * 2

  @functools.partial(jax.jit, inline=True)
  def h(x):
    return g(x) * 3

  @jax.jit
  def f(x):
    return h(x)

  x_in = jnp.arange(8)
  self.assertArraysEqual(f(x_in), x_in * 6)

  hlo_text = f.lower(x_in).as_text('hlo')
  self.assertRegex(
      hlo_text, 'to_apply=g.*frontend_attributes={_xla_compute_type="host"}')
def test_compute_on_reduction(self):
  """A host-side reduction combined with a device-side one."""
  @compute_on('device_host')
  @jax.jit
  def g(x):
    # Reduction generates multiple host computations (inside a single host
    # computation module): the main one and a reduction body.
    return jnp.sum(x)

  def host_sum_times_device_sum(x):
    return g(x) * jnp.sum(x)

  f = jax.jit(host_sum_times_device_sum)
  h = jax.jit(
      host_sum_times_device_sum,
      out_shardings=SingleDeviceSharding(
          jax.devices()[0], memory_kind='pinned_host'))

  inp = jnp.arange(8)
  expected = np.sum(inp) * np.sum(inp)

  self.assertArraysEqual(f(inp), expected)
  self.assertIn('_xla_compute_type', f.lower(inp).as_text())

  host_out = h(inp)
  self.assertArraysEqual(host_out, expected)
  self.assertEqual(host_out.sharding.memory_kind, 'pinned_host')
def test_compute_host_loop(self):
  """lecun_normal init runs under compute_on, both jitted and un-jitted."""
  # TODO(apaszke): Remove after 12 weeks have passed.
  if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
    self.skipTest("Requires libtpu built after 2024-12-19")

  @compute_on('device_host')
  @jax.jit
  def jitted_host_init():
    key = jax.random.key(0)
    return jax.nn.initializers.lecun_normal()(key, (2, 2), jnp.float32)

  @compute_on('device_host')
  def eager_host_init():
    key = jax.random.key(0)
    return jax.nn.initializers.lecun_normal()(key, (2, 2), jnp.float32)

  for init_fn in (jitted_host_init, eager_host_init):
    init_fn()  # doesn't crash
def test_nested_compute_error(self):
  """A 'device' compute_on nested inside a 'device_host' one is rejected."""
  @compute_on('device')
  @jax.jit
  def inner(x):
    return x * 2

  @compute_on('device_host')
  @jax.jit
  def middle(x):
    return inner(x)

  @jax.jit
  def outer(x):
    return middle(x)

  expected_msg = ("Nesting `compute_on` with different compute types is not "
                  "supported yet.")
  with self.assertRaisesRegex(NotImplementedError, expected_msg):
    outer(jnp.arange(8))
def test_compute_on_grad(self):
  """Gradients through a host-computed jit stay tagged as host computations."""
  @compute_on('device_host')
  @jax.jit
  def g(x):
    return jnp.sin(x)

  def f(x):
    return jnp.sum(g(x))

  inp = jnp.arange(8.)
  grad_fn = jax.jit(jax.grad(f))
  jtu.check_grads(grad_fn, (inp,), order=2)

  hlo_text = grad_fn.lower(inp).as_text('hlo')
  # Two host-tagged calls are expected (presumably one forward, one backward).
  host_calls = re.findall(r"call.*to_apply.*_xla_compute_type", hlo_text)
  self.assertLen(host_calls, 2)
def test_compute_on_remat(self):
  """A host-computed jit under remat with a recompute-everything policy."""
  def recompute_all(*_, **__):
    # Remat policy: never save residuals, always recompute.
    return Recompute

  @compute_on('device_host')
  @jax.jit
  def g(x):
    for _ in range(3):
      x = jnp.sin(x)
    return x

  @functools.partial(remat, policy=recompute_all)
  def f(x):
    return jnp.sum(g(x))

  inp = jnp.arange(16.)

  # Execution test.
  grad_fn = jax.jit(jax.grad(f))
  grad_fn(inp)  # doesn't crash

  hlo_text = grad_fn.lower(inp).as_text('hlo')
  host_calls = re.findall(r"call.*to_apply.*_xla_compute_type", hlo_text)
  self.assertLen(host_calls, 2)
def test_nested_no_op_compute(self):
  """Nesting compute_on with the same compute type is allowed."""
  mesh = jtu.create_mesh((2, 2), ('x', 'y'))
  sharding = NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np.arange(16).reshape(8, 2), sharding)

  @compute_on('device_host')
  @jax.jit
  def times_two(x):
    return x * 2

  @compute_on('device_host')
  @jax.jit
  def times_six(x):
    return times_two(x * 3)

  @jax.jit
  def entry(x):
    return times_six(x)

  result = entry(arr)
  self.assertArraysEqual(result, arr * 6)
  self.assertEqual(result.sharding, sharding)
def test_sharded_compute_on_host(self):
  # A host-annotated jitted multiply called from a device jit. The result
  # must keep the input's ('x', 'y') sharding and equal (3x) * (3x).
  mesh = jtu.create_mesh((2, 2), ('x', 'y'))
  sharding = NamedSharding(mesh, P('x', 'y'))
  host_data = np.arange(16).reshape(8, 2)
  arr = jax.device_put(host_data, sharding)

  @compute_on('device_host')
  @jax.jit
  def host_mul(lhs, rhs):
    return lhs * rhs

  @jax.jit
  def run(a):
    tripled = a * 3
    return host_mul(tripled, tripled)

  result = run(arr)
  tripled_np = host_data * 3
  self.assertEqual(result.sharding, sharding)
  self.assertArraysEqual(result, tripled_np * tripled_np)
def test_host_offload_in_custom_vjp(self):
  # Uses a custom_vjp whose fwd stores a residual in pinned_host memory. Its
  # bwd recomputes that value, moves it to pinned_host, and compares the two
  # copies on the host. The returned "gradient" is that equality mask, so it
  # is expected to be all ones.
  if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
    self.skipTest("This test requires an xla_version >= 2.")

  @jax.custom_vjp
  def f(x):
    return jnp.sin(x)

  # Elementwise equality computed on the host, as float32 (1.0 where equal).
  @compute_on('device_host')
  @jax.jit
  def eq(x, y):
    return (x == y).astype(jnp.float32)

  def f_fwd(x):
    y = x * 2
    # Residual placed in pinned_host memory.
    z = jax.device_put(y, TransferToMemoryKind('pinned_host'))
    return y, (x, z)

  def f_bwd(res, tx):
    x, z = res
    y = x * 2
    z2 = jax.device_put(y, TransferToMemoryKind('pinned_host'))
    # The incoming cotangent `tx` is not used; the equality mask is returned.
    return (eq(z, z2),)

  f.defvjp(f_fwd, f_bwd)
  g = jax.jit(jax.grad(lambda x: f(x).sum()))

  x = jnp.ones(3) * 4
  all_true = jnp.ones(3)
  self.assertArraysEqual(g(x), all_true)
def test_host_offload_in_custom_vjp_sharded(self):
  # Sharded version of the custom_vjp host-offload check. Residuals are moved
  # with a P('x') NamedSharding whose memory kind is pinned_host. The
  # "gradient" is the host-computed equality mask, expected all ones.
  if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
    self.skipTest("This test requires an xla_version >= 2.")
  mesh = jtu.create_mesh((2, 2), ("x", "y"))
  s = NamedSharding(mesh, P('x'))

  @jax.custom_vjp
  def f(x):
    return jnp.sin(x)

  # Elementwise equality computed on the host, as float32 (1.0 where equal).
  @compute_on('device_host')
  @jax.jit
  def eq(x, y):
    return (x == y).astype(jnp.float32)

  def f_fwd(x):
    y = x * 2
    # Residual placed in pinned_host memory with the same P('x') sharding.
    z = jax.device_put(y, s.with_memory_kind('pinned_host'))
    return y, (x, z)

  def f_bwd(res, tx):
    x, z = res
    y = x * 2
    z2 = jax.device_put(y, s.with_memory_kind('pinned_host'))
    # The incoming cotangent `tx` is not used; the equality mask is returned.
    return (eq(z, z2),)

  f.defvjp(f_fwd, f_bwd)
  g = jax.jit(jax.grad(lambda x: f(x).sum()))

  arr = jax.device_put(jnp.ones(4) * 4, s)
  all_true = jnp.ones(4)
  self.assertArraysEqual(g(arr), all_true)
def test_scan_offload(self):
  # Host computation inside lax.scan, exercised two ways:
  #   1. a compute_on context manager inside the scan body;
  #   2. a scan body that is itself a host-annotated jitted function.
  # In both cases each scanned slice gets the constant carry (1) added.
  xs_in = jnp.arange(4096).reshape(16, 16, 16)

  @jax.jit
  def scan_with_ctx(xs):
    def step(carry, x):
      with compute_on('device_host'):
        summed = x + carry
      return carry, summed
    return jax.lax.scan(step, 1, xs)[1]

  self.assertArraysEqual(scan_with_ctx(xs_in), xs_in + 1)

  @compute_on('device_host')
  @jax.jit
  def host_step(carry, x):
    summed = x + carry
    return carry, summed

  @jax.jit
  def scan_with_decorator(xs):
    return jax.lax.scan(host_step, 1, xs)[1]

  self.assertArraysEqual(scan_with_decorator(xs_in), xs_in + 1)
@parameterized.parameters(True, False)
def test_copy_offload(self, jit_compute_fn: bool):
  # test an explicit copy within the host computation.
  # `jit_compute_fn` selects whether the host function is additionally wrapped
  # in jax.jit before compute_on is applied.
  def g(x):
    return jnp.copy(x) * 2

  @jax.jit
  def f(x):
    if jit_compute_fn:
      y = compute_on("device_host")(jax.jit(g))(x)
    else:
      y = compute_on("device_host")(g)(x)
    return y * 3

  inp = jnp.arange(8)
  out = f(inp)
  self.assertArraysEqual(out, inp * 6)
  # The host region must appear as a compute-type annotation in the lowering.
  lowered_text = f.lower(jnp.arange(8)).as_text()
  self.assertIn('_xla_compute_type', lowered_text)
def test_pure_host_data_and_compute(self):
  # The input already lives in pinned_host memory and the computation is
  # host-annotated. out_shardings keeps the result in pinned_host memory.
  if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
    self.skipTest("This test requires an xla_version >= 2.")
  mesh = jtu.create_mesh((2, 2), ('x', 'y'))
  s = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
  np_inp = np.arange(16).reshape(8, 2)
  arr_host = jax.device_put(np_inp, s)

  @compute_on('device_host')
  @jax.jit
  def g(x):
    return x * x

  @functools.partial(jax.jit, out_shardings=s)
  def f(x):
    return g(x)

  out = f(arr_host)
  self.assertEqual(out.sharding, s)
  self.assertEqual(out.sharding.memory_kind, 'pinned_host')
  self.assertArraysEqual(out, np_inp * np_inp)
def test_eager_compute(self):
  # compute_on also works as a context manager around eager (un-jitted) ops.
  x = jnp.arange(8.)
  with compute_on('device_host'):
    doubled = x * 2
    result = jnp.sin(doubled)
  self.assertArraysAllClose(result, jnp.sin(x * 2))
def test_compute_per_annotation(self):
  # compute_on is applied under jax.jit as a per-function annotation. The
  # function is called once with a sharded jax.Array and once with a plain
  # numpy array.
  mesh = jtu.create_mesh((2, 2), ("x", "y"))
  s = NamedSharding(mesh, P("x", "y"))
  np_inp = np.arange(16.).reshape(8, 2)
  arr = jax.device_put(np_inp, s)

  @jax.jit
  @compute_on('device_host')
  def f(x):
    return jnp.sin(x * 2)

  # Sharded input.
  out = f(arr)
  self.assertArraysAllClose(out, np.sin(np_inp * 2))

  # Unsharded numpy input.
  out2 = f(np_inp)
  self.assertArraysAllClose(out2, np.sin(np_inp * 2))
def test_jit_host_multi_outputs(self):
  # Two outputs of a single jit are placed in different memory kinds, one in
  # pinned_host and one in device, via device_put inside the computation.
  if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
    self.skipTest("This test requires an xla_version >= 2.")
  _, s, np_inp, inp = _create_inputs((8, 2), P("x"))

  @jax.jit
  def f(a, b):
    sin_a = jnp.sin(a)
    cos_b = jnp.cos(b)
    on_host = jax.device_put(sin_a, s.with_memory_kind("pinned_host"))
    on_device = jax.device_put(cos_b, s.with_memory_kind("device"))
    return on_host, on_device

  host_out, dev_out = f(inp, inp)
  self.assertArraysAllClose(host_out, np.sin(np_inp))
  self.assertArraysAllClose(dev_out, np.cos(np_inp))
  self.assertEqual(host_out.sharding, s.with_memory_kind("pinned_host"))
  self.assertEqual(dev_out.sharding, s.with_memory_kind("device"))
def test_jit_out_shardings_single_output(self):
  # A scalar jit output is placed in pinned_host memory in two ways: through
  # out_shardings, and through jax.device_put inside the function. In both
  # cases the output array and the compiled executable's output memory kind
  # must report pinned_host.
  mesh, _, _, inp = _create_inputs((8, 2), P("x", "y"))
  out_s = NamedSharding(mesh, P(), memory_kind="pinned_host")

  # Placement via out_shardings.
  @functools.partial(jax.jit, out_shardings=out_s)
  def g(x):
    return jnp.sum(x * 2)

  out = g(inp)
  self.assertEqual(out.sharding, out_s)
  executable_mk = get_memory_kinds_from_executable(g, [inp])
  self._check_mem_kind(executable_mk[0], out.sharding, "pinned_host")

  # Placement via device_put inside the jitted function.
  @jax.jit
  def h(x):
    x = jnp.sum(x * 2)
    out = jax.device_put(x, out_s)
    return out

  out = h(inp)
  self.assertEqual(out.sharding, out_s)
  executable_mk = get_memory_kinds_from_executable(h, [inp])
  self._check_mem_kind(executable_mk[0], out.sharding, "pinned_host")
def test_jit_in_shardings(self):
  # An in_sharding with memory_kind pinned_host must reject device-resident
  # arguments, whether uncommitted or committed. An in_sharding with memory
  # kind "device" is accepted, and the executable reports "device".
  _, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"))

  @functools.partial(jax.jit, in_shardings=s.with_memory_kind("pinned_host"))
  def f(x):
    return x * 2

  with self.assertRaisesRegex(
      ValueError,
      "Memory kinds passed to jax.jit does not match memory kind on the"
      " respective arg. Got pjit memory kind: pinned_host, arg memory kind:"
      " device for arg shape.*"):
    f(jnp.arange(16).reshape(8, 2))  # uncommitted inp also raises error

  with self.assertRaisesRegex(
      ValueError,
      "Memory kinds passed to jax.jit does not match memory kind on the"
      " respective arg. Got pjit memory kind: pinned_host, arg memory kind:"
      " device for arg shape.*"):
    f(inp)  # committed inp raises error.

  @functools.partial(jax.jit, in_shardings=s.with_memory_kind("device"))
  def g(x):
    return x * 2

  out = g(inp)
  executable_kind = get_memory_kinds_from_executable(g, [inp])
  self.assertArraysEqual(out, np_inp * 2)
  self._check_mem_kind(executable_kind[0], out.sharding, "device")
def test_jit_in_out_shardings(self):
  # in_shardings and out_shardings are given together. With a "device" input,
  # the replicated (P()) sum output is checked first with memory kind
  # "device" and then with "pinned_host".
  mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"), mem_kind="device")
  out_s = NamedSharding(mesh, P(), memory_kind="device")

  @functools.partial(jax.jit, in_shardings=s, out_shardings=out_s)
  def f(x):
    return jnp.sum(x)

  out = f(inp)
  executable_kind = get_memory_kinds_from_executable(f, [inp])
  self.assertArraysEqual(out, np.sum(np_inp))
  self._check_mem_kind(executable_kind[0], out.sharding, "device")

  # Same computation, but the output is requested in pinned_host memory.
  @functools.partial(
      jax.jit,
      in_shardings=s,
      out_shardings=out_s.with_memory_kind("pinned_host"),
  )
  def g(x):
    return jnp.sum(x)

  out = g(inp)
  executable_kind = get_memory_kinds_from_executable(g, [inp])
  self.assertArraysEqual(out, np.sum(np_inp))
  self._check_mem_kind(executable_kind[0], out.sharding, "pinned_host")
def test_device_put_different_devices(self):
  # Inside jit, device_put of an input sharded over the 2x2 mesh onto a
  # single device mixes device sets, which jit must reject.
  _, _, _, sharded_inp = _create_inputs((8, 2), P("x", "y"))

  @jax.jit
  def move_to_one_device(arr):
    target = SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")
    return jax.device_put(arr, target)

  with self.assertRaisesRegex(
      ValueError, "Received incompatible devices for jitted computation"):
    move_to_one_device(sharded_inp)
def test_jit_cpp_cache_hit(self):
  # One array uses a NamedSharding with no memory_kind and the other uses an
  # explicit memory_kind="device". Calling the same jit on both is expected
  # to cause only one C++ cache miss, which implies the unspecified kind is
  # treated like "device" here.
  mesh, _, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
  inp2 = jax.device_put(
      np_inp, NamedSharding(mesh, P("x", "y"), memory_kind="device"))

  f = jax.jit(lambda x: x @ x.T)
  with jtu.count_pjit_cpp_cache_miss() as count:
    out = f(inp)
    out2 = f(inp2)
  self.assertEqual(count(), 1)

  self.assertArraysEqual(out, np_inp @ np_inp.T)
  self.assertArraysEqual(out2, np_inp @ np_inp.T)
def test_jit_compilation_cache_hit(self):
  # Compares a NamedSharding input with an equivalent GSPMDSharding input
  # (same devices, same HLO sharding, memory_kind="device"). The test asserts
  # two C++ cache misses and two lowerings for the two calls.
  # NOTE(review): per the test name, presumably the backend compilation cache
  # is what gets hit; that is not asserted here.
  if config.use_shardy_partitioner.value:
    self.skipTest("Shardy doesn't support GSPMDSharding")
  mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
  inp2 = jax.device_put(
      np_inp, GSPMDSharding(tuple(mesh.devices.flat),
                            s._to_xla_hlo_sharding(inp.ndim),
                            memory_kind="device")
  )

  f = jax.jit(lambda x: x @ x.T)

  with (jtu.count_pjit_cpp_cache_miss() as cpp_count,
        jtu.count_jit_and_pmap_lowerings() as lowering_count):
    f(inp)
    f(inp2)
  self.assertEqual(cpp_count(), 2)
  self.assertEqual(lowering_count(), 2)
def test_jit_cpp_cache_output_hit(self):
  # Passing a jit's output back into the same jit must hit the C++ cache, so
  # only the first call counts as a miss.
  _, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="device")

  @jax.jit
  def mul_two(x):
    return x * 2

  with jtu.count_pjit_cpp_cache_miss() as count:
    out = mul_two(inp)
    mul_two(out)
  self.assertEqual(count(), 1)
def test_jit_cache_hit_with_default_and_specified_mem_kind(self):
  # Two jits whose in_shardings differ only by memory kind (unspecified vs.
  # explicit "device") should lower just once in total.
  _, s, np_inp, _ = _create_inputs((8, 2), P("x", "y"))
  _, s2, np_inp2, _ = _create_inputs((8, 2), P("x", "y"), mem_kind="device")

  def mul(x):
    return x @ x.T

  f = jax.jit(mul, in_shardings=s)
  g = jax.jit(mul, in_shardings=s2)

  with jtu.count_jit_and_pmap_lowerings() as count:
    out = f(np_inp)
    out2 = g(np_inp2)
  self.assertEqual(count(), 1)

  self.assertArraysEqual(out, np_inp @ np_inp.T)
  self.assertArraysEqual(out2, np_inp2 @ np_inp2.T)
def test_sharding_devices_indices_map_cache_hit(self):
  # Two shardings differ only by memory kind (unspecified vs. explicit
  # "device"). The second devices_indices_map call must be served from the
  # `common_devices_indices_map` cache: one more hit and no new misses.
  mesh = jtu.create_mesh((2, 2), ("x", "y"))
  shape = (8, 2)
  s1 = NamedSharding(mesh, P("x", "y"))
  s2 = NamedSharding(mesh, P("x", "y"), memory_kind="device")

  s1.devices_indices_map(shape)
  cache_info1 = common_devices_indices_map.cache_info()
  s2.devices_indices_map(shape)
  cache_info2 = common_devices_indices_map.cache_info()
  self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
  self.assertEqual(cache_info2.misses, cache_info1.misses)
def test_no_donation_across_memory_kinds(self):
  # A device-memory input is donated to a jit whose output lives in
  # pinned_host memory. Buffers must not be aliased across memory kinds:
  # jax warns that the donated buffer was unusable, the HLO has no
  # input_output_alias, and the input is not deleted.
  if xb.using_pjrt_c_api():
    # self.skipTest (which itself raises unittest.SkipTest) matches the
    # skipping style used by the other tests here.
    self.skipTest("GetOutputShardings not supported in PJRT C API")
  mesh = jtu.create_mesh((2, 1), ("x", "y"))
  np_inp = np.arange(16).reshape(8, 2)
  s_hbm = NamedSharding(mesh, P("x"))
  s_host = s_hbm.with_memory_kind("pinned_host")
  inp = jax.device_put(np_inp, s_hbm)

  @functools.partial(jax.jit, out_shardings=s_host, donate_argnums=0)
  def f(x):
    return x * 2

  with self.assertWarnsRegex(
      UserWarning, "Some donated buffers were not usable"):
    f(inp)

  lowered_text = f.lower(inp).as_text("hlo")
  self.assertNotIn("input_output_alias", lowered_text)
  self.assertNotDeleted(inp)
def test_single_mem_kind_donation_default_mem_kind(self):
  # Input and output use the same sharding with the default memory kind, so
  # donation must alias the buffers: input_output_alias appears in the HLO
  # and the donated input is deleted.
  mesh = jtu.create_mesh((2,), "x")
  s = NamedSharding(mesh, P())

  @functools.partial(jax.jit, out_shardings=s, donate_argnums=0)
  def f(inp1):
    return inp1 * 2

  x = jax.device_put(np.arange(16).reshape(8, 2), s)

  f(x)

  lowered_text = f.lower(x).as_text("hlo")
  self.assertIn("input_output_alias", lowered_text)
  self.assertDeleted(x)
def test_compute_offload_inside_shmap(self):
  # A host-annotated jitted function is called per shard inside shard_map.
  # The result is x * 3 * 2 * 4 = x * 24.
  mesh = jtu.create_mesh((2, 2), ('x', 'y'))
  s = NamedSharding(mesh, P('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  arr = jax.device_put(np_inp, s)

  @compute_on('device_host')
  @jax.jit
  def g(x):
    return x * 2

  def f(x):
    x = x * 3
    y = g(x)
    return y * 4

  out = jax.jit(shard_map(f, mesh=mesh, in_specs=P('x', 'y'),
                          out_specs=P('x', 'y')))(arr)
  self.assertArraysEqual(out, np_inp * 24)
def test_qr_decomposition_offload(self):
  # Runs a QR on device followed by a QR inside a host-annotated function.
  # The lowering must contain both the '@lapack_sgeqrf' and '@Qr' call
  # targets (presumably the host LAPACK and device implementations
  # respectively -- TODO confirm). The result must match running both QRs
  # without compute_on.
  if jtu.is_cloud_tpu():
    self.skipTest("Test fails on cloud TPU")

  shape = (3, 3)
  dtype = np.float32
  operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)

  @compute_on("device_host")
  @jax.jit
  def g(x):
    return lax.linalg.qr(x, full_matrices=True)

  @jax.jit
  def f(x):
    x, _ = lax.linalg.qr(x, full_matrices=True)
    x, _ = g(x)
    return x

  out = f(operand)  # doesn't crash
  lowered_text = f.lower(operand).as_text()
  self.assertIn('@lapack_sgeqrf', lowered_text)
  self.assertIn('@Qr', lowered_text)

  # Reference result: the same two QRs with no compute_on annotation.
  @jax.jit
  def h(x):
    x, _ = lax.linalg.qr(x, full_matrices=True)
    x, _ = lax.linalg.qr(x, full_matrices=True)
    return x

  expected_out = h(operand)
  self.assertArraysAllClose(out, expected_out, rtol=1e-3)
def test_mem_kind_donation_pinned_host(self):
  # A host-annotated jit donates one pinned_host input and one device input.
  # Its out_shardings keep each output in the same memory kind as the matching
  # input, so both donations are used: input_output_alias appears in the HLO
  # and both inputs are deleted.
  mesh = jtu.create_mesh((2,), "x")
  s = NamedSharding(mesh, P(), memory_kind='pinned_host')
  s_dev = s.with_memory_kind('device')

  @compute_on('device_host')
  @functools.partial(jax.jit, out_shardings=(s, s_dev), donate_argnums=(0, 1))
  def f(inp1, inp2):
    return inp1 * 2, inp2 * 2

  np_inp = np.arange(16).reshape(8, 2)
  x = jax.device_put(np_inp, s)
  x_dev = jax.device_put(np_inp, s_dev)

  f(x, x_dev)

  lowered_text = f.lower(x, x_dev).as_text("hlo")
  self.assertIn("input_output_alias", lowered_text)
  self.assertDeleted(x)
  self.assertDeleted(x_dev)
@parameterized.parameters("pinned_host", "device")
def test_identity_mem_kind_donation(self, mem_kind):
  # An identity jit donates its input, with input and output in the same
  # memory kind (pinned_host or device). The buffer must be aliased and the
  # input deleted.
  mesh = jtu.create_mesh((2,), "x")
  s = NamedSharding(mesh, P(), memory_kind=mem_kind)

  @functools.partial(jax.jit, out_shardings=s, donate_argnums=0)
  def f(inp):
    return inp

  np_inp = np.arange(16).reshape(8, 2)
  x = jax.device_put(np_inp, s)

  f(x)

  lowered_text = f.lower(x).as_text("hlo")
  self.assertIn("input_output_alias", lowered_text)
  self.assertDeleted(x)
def test_compute_offload_with_donation(self):
  # A host-annotated computation runs under a jit that donates both
  # arguments: x on device and y in pinned_host memory. The expected values
  # come from fresh copies (x1, y1) because the donated x and y may have been
  # deleted by the call.
  sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
  p_sharding = jax.sharding.SingleDeviceSharding(
      jax.devices()[0], memory_kind="pinned_host"
  )

  @compute_on("device_host")
  @jax.jit
  def host_fn(x_in, y_in):
    return x_in * x_in, y_in + y_in

  def test_fn(x_in, y_in):
    x_out, y_out = host_fn(x_in, y_in)
    return x_out, y_out

  x = jnp.arange(0, 1024, dtype=jnp.float32)
  y = jnp.arange(0, 1024, dtype=jnp.float32)
  y = jax.device_put(y, p_sharding)

  # Independent copies used only to build the expected values.
  x1 = jnp.arange(0, 1024, dtype=jnp.float32)
  y1 = jnp.arange(0, 1024, dtype=jnp.float32)

  jit_fn = jax.jit(
      test_fn,
      in_shardings=(sharding, p_sharding),
      out_shardings=(sharding, p_sharding),
      donate_argnums=(0, 1),
  )
  x_out, y_out = jit_fn(x, y)
  self.assertArraysEqual(x_out, x1 * x1)
  self.assertArraysEqual(y_out, y1 + y1)
def test_compute_offload_with_linear_layout(self):
  # Single-device host computation with explicit layouts. x uses a layout
  # tiled (8, 128) on device; y uses a layout with tiling (1,) in pinned_host
  # memory. out_shardings request the same layouts for the outputs.
  # TODO(apaszke): Remove after 12 weeks have passed.
  if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
    self.skipTest("Requires libtpu built after 2024-12-19")
  sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
  p_sharding = jax.sharding.SingleDeviceSharding(
      jax.devices()[0], memory_kind="pinned_host"
  )

  @compute_on("device_host")
  @jax.jit
  def host_fn(x_in, y_in):
    return x_in * x_in, y_in + y_in

  def test_fn(x_in, y_in):
    x_out, y_out = host_fn(x_in, y_in)
    return x_out, y_out

  x = jnp.arange(0, 1024, dtype=jnp.float32)
  x = jnp.reshape(x, (16, 64))
  y = jnp.arange(0, 1024, dtype=jnp.float32)
  y = jnp.reshape(y, (16, 64))
  custom_dll = DLL(major_to_minor=(0, 1), _tiling=((8, 128),))
  custom_dll_linear = DLL(major_to_minor=(0, 1), _tiling=((1,),))
  x = jax.device_put(x, Layout(custom_dll, sharding))
  y = jax.device_put(y, Layout(custom_dll_linear, p_sharding))

  # Independent copies used only to build the expected values.
  x1 = jnp.arange(0, 1024, dtype=jnp.float32)
  x1 = jnp.reshape(x1, (16, 64))
  y1 = jnp.arange(0, 1024, dtype=jnp.float32)
  y1 = jnp.reshape(y1, (16, 64))

  jit_fn = jax.jit(
      test_fn,
      out_shardings=(
          Layout(custom_dll, sharding),
          Layout(custom_dll_linear, p_sharding),
      ),
  )
  x_out, y_out = jit_fn(x, y)
  self.assertArraysEqual(x_out, x1 * x1)
  self.assertArraysEqual(y_out, y1 + y1)
def test_compute_offload_mesh_with_linear_layout(self):
  # Host computation over a (2, 2) mesh with explicit layouts. x uses a layout
  # tiled (8, 128) on device; y uses a layout with tiling (1,) in pinned_host
  # memory. Both are (32, 64) and sharded P("x", "y"). out_shardings request
  # the same layouts for the outputs.
  mesh = jtu.create_mesh((2, 2), ("x", "y"))
  sharding = NamedSharding(mesh, P("x", "y"))
  p_sharding = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host")

  @compute_on("device_host")
  @jax.jit
  def host_fn(x_in, y_in):
    return x_in * x_in, y_in + y_in

  def test_fn(x_in, y_in):
    x_out, y_out = host_fn(x_in, y_in)
    return x_out, y_out

  x = jnp.arange(0, 2048, dtype=jnp.float32)
  x = jnp.reshape(x, (32, 64))
  y = jnp.arange(0, 2048, dtype=jnp.float32)
  y = jnp.reshape(y, (32, 64))
  custom_dll = DLL(major_to_minor=(0, 1), _tiling=((8, 128),))
  custom_dll_linear = DLL(major_to_minor=(0, 1), _tiling=((1,),))
  x = jax.device_put(x, Layout(custom_dll, sharding))
  y = jax.device_put(y, Layout(custom_dll_linear, p_sharding))

  # Independent copies used only to build the expected values.
  x1 = jnp.arange(0, 2048, dtype=jnp.float32)
  x1 = jnp.reshape(x1, (32, 64))
  y1 = jnp.arange(0, 2048, dtype=jnp.float32)
  y1 = jnp.reshape(y1, (32, 64))

  jit_fn = jax.jit(
      test_fn,
      out_shardings=(
          Layout(custom_dll, sharding),
          Layout(custom_dll_linear, p_sharding),
      ),
  )
  x_out, y_out = jit_fn(x, y)
  self.assertArraysEqual(x_out, x1 * x1)
  self.assertArraysEqual(y_out, y1 + y1)
def test_compute_on_cache_miss(self):
  # Calls the same jitted function under two different compute_on contexts.
  # The test expects the compute type to force retracing: 4 tracing cache
  # misses in total.
  @jax.jit
  def f(x):
    return x * 2

  inp = jnp.arange(10)

  with jtu.count_jit_tracing_cache_miss() as count:
    with compute_on('device_host'):
      f(inp)

    with compute_on('device'):
      f(inp)

  # 2 misses for `f` and 2 for `mul` (the compute type changes for `mul`).
  self.assertEqual(count(), 4)
def test_offload_take_host(self):
  # Smoke test: jnp.take followed by einsums and gelu, all inside a
  # host-annotated jitted function, must compile and run.
  # TODO(apaszke): Remove after 12 weeks have passed.
  if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
    self.skipTest("Requires libtpu built after 2024-12-19")

  @compute_on('device_host')
  @jax.jit
  def peer_forward(x, experts, indices, scores):
    w = jnp.take(experts, indices.astype(int), axis=0)
    # Last axis of `w` (size 3) holds the gate / down / up weights.
    w_gate, w_down, w_up = w[..., 0], w[..., 1], w[..., 2]
    g = jnp.einsum('btd, bthkd->bthk', x, w_gate)
    x = jnp.einsum('btd, bthkd->bthk', x, w_down)
    x = x * jax.nn.gelu(g) * scores
    return jnp.einsum('bthk, bthkd->btd', x, w_up)

  x = jnp.ones((16, 4, 32))
  experts = jnp.ones((128, 32, 3))
  indices = jnp.ones((16, 4, 4, 2), dtype=jnp.int32)
  scores = jnp.ones((16, 4, 4, 2))
  jax.jit(peer_forward)(x, experts, indices, scores)  # doesn't crash
class StreamAnnotationTest(jtu.JaxTestCase):
  # Tests for compute_on GPU stream annotations ('gpu_stream:N').

  def test_stream_annotation_inside_shmap(self):
    # Two jitted multiplies are annotated with different GPU streams and
    # called inside shard_map. With all-ones inputs each element is
    # 1 * 1 + (3 * 1) * (2 * 1) = 7.
    if not jtu.test_device_matches(["gpu"]):
      self.skipTest("Stream annotation is only supported on GPU.")
    mesh = jtu.create_mesh((2,), ('x',))
    sharding = NamedSharding(mesh, P('x'))
    ones = np.ones((8,))
    lhs = jax.device_put(ones, sharding)
    rhs = jax.device_put(ones, sharding)

    @compute_on('gpu_stream:1')
    @jax.jit
    def mul_on_stream1(a, b):
      return a * b

    @compute_on('gpu_stream:2')
    @jax.jit
    def mul_on_stream2(a, b):
      return a * b

    def per_shard(a, b):
      first = mul_on_stream1(a, b)
      second = mul_on_stream2(3 * a, 2 * b)
      return first + second

    shmapped = shard_map(per_shard, mesh=mesh, in_specs=(P('x'), P('x')),
                         out_specs=P('x'))
    out = jax.jit(shmapped)(lhs, rhs)
    self.assertArraysEqual(out, lhs * 7)
class ActivationOffloadingTest(jtu.JaxTestCase):
  # Tests for remat / checkpoint policies that offload residuals between
  # "device" and "pinned_host" memory. They check the forward/backward jaxpr
  # text and, when available, the compiled HLO text and memory analysis.

  def setUp(self):
    # Skip on anything other than TPU/GPU, before the base-class setUp runs.
    if not jtu.test_device_matches(["tpu", "gpu"]):
      self.skipTest("Memories do not work on CPU backend.")
    super().setUp()

  def test_remat_jaxpr_offloadable(self):
    mesh = jtu.create_mesh((2,), ("x",))
    inp = jax.device_put(np.arange(16.), NamedSharding(mesh, P("x")))

    # Policy that offloads every residual from device to pinned_host.
    def policy(prim, *avals, **params):
      return Offloadable(src="device", dst="pinned_host")

    @functools.partial(remat, policy=policy)
    def f(x):
      x = jnp.sin(x)
      x = jnp.sin(x)
      x = jnp.sin(x)
      return jnp.sum(x)

    fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, inp)

    self.assertLen(fwd_jaxpr.out_avals, 4)  # 1 output, 3 offloaded residuals
    fwd_mem_kind_count = str(fwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='pinned_host')")
    self.assertEqual(fwd_mem_kind_count, 3)

    self.assertLen(bwd_jaxpr.in_avals, 4)  # 3 offloaded residuals, 1 input
    bwd_mem_kind_count = str(bwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='device')")
    self.assertEqual(bwd_mem_kind_count, 3)

    # Execution test.
    f = jax.jit(jax.grad(f))
    f(inp)  # doesn't crash

    compiled_f = f.lower(inp).compile()

    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      # 'S(5)' is the HLO memory-space annotation the test looks for,
      # presumably host memory -- TODO confirm. Here the offloaded buffers
      # are expected to appear in copy-start / copy-done ops.
      self.assertIn('S(5)', compiled_text)
      self.assertRegex(compiled_text, r"copy-start.*S\(5\)")
      self.assertRegex(compiled_text, r"copy-done.*S\(5\)")

    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
      if jtu.pjrt_c_api_version_at_least(0, 43):
        self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

  def test_remat_scan_jaxpr_offloadable(self):
    mesh = jtu.create_mesh((2,), ("x",))
    shape = (256, 128)
    np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    s = NamedSharding(mesh, P("x"))
    inp = jax.device_put(np_inp, s)

    # The same name may not be listed as both saveable and offloadable.
    with self.assertRaisesRegex(
        ValueError, "The names should be exclusive and should not intersect"):
      jax.checkpoint_policies.save_and_offload_only_these_names(
          names_which_can_be_saved=["y"], names_which_can_be_offloaded=["y", "w"],
          offload_src="device", offload_dst="pinned_host")

    # Save "y"; offload "z" and "w" from device to pinned_host.
    policy = jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z", "w"],
        offload_src='device', offload_dst='pinned_host')

    @functools.partial(remat, policy=policy)
    def f(x):
      def g(ys, _):
        y, _ = ys
        y = checkpoint_name(jnp.sin(y), "y")
        z = checkpoint_name(jnp.sin(y), "z")
        z = jax.lax.with_sharding_constraint(z, s)
        w = checkpoint_name(jnp.sin(z), "w")
        return (w, jnp.sum(w)), None
      # scan(...)[0] is the final carry (w, sum(w)); only sum(w) is returned.
      _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
      return scan_out

    fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, inp)

    # NOTE(review): only 2 pinned_host transfers are asserted below (for the
    # two offloadable names); confirm the "3 offloaded residuals" breakdown.
    self.assertLen(fwd_jaxpr.out_avals, 5)  # 2 output, 3 offloaded residuals
    fwd_mem_kind_count = str(fwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='pinned_host')")
    self.assertEqual(fwd_mem_kind_count, 2)

    self.assertLen(bwd_jaxpr.in_avals, 5)  # 3 offloaded residuals, 2 input
    bwd_mem_kind_count = str(bwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='device')")
    self.assertEqual(bwd_mem_kind_count, 2)

    f = jax.jit(jax.grad(f))
    f(inp)  # doesn't crash

    compiled_f = f.lower(inp).compile()

    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      # Inside scan the offloaded buffers must NOT be moved by
      # copy-start / copy-done ops.
      self.assertIn('S(5)', compiled_text)
      self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
      self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")

    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
      self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

  def test_remat_scan_layout_change_offloadable(self):
    mesh = jtu.create_mesh((2,), ("x",))
    shape = (256, 128)
    np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    s = NamedSharding(mesh, P("x"))
    inp = jax.device_put(np_inp, s)

    # Save "y"; offload "z" and "w" from device to pinned_host.
    policy = jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z", "w"],
        offload_src='device', offload_dst='pinned_host')

    @functools.partial(remat, policy=policy)
    def f(x):
      def g(ys, _):
        y, _ = ys
        y = checkpoint_name(jnp.sin(y), "y")
        z = checkpoint_name(jnp.sin(y), "z")
        z = jax.lax.with_sharding_constraint(z, s)
        # Transpose so the offloaded "w" has a different layout than "z".
        z = z.T
        w = checkpoint_name(jnp.sin(z), "w")
        return (w.T, jnp.sum(w)), None
      _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
      return scan_out

    f = jax.jit(jax.grad(f))
    f(inp)  # doesn't crash

    compiled_f = f.lower(inp).compile()

    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      # No copy-start/copy-done on S(5); dynamic(-update)-slice start/done
      # ops touching S(5) are expected instead.
      self.assertIn('S(5)', compiled_text)
      self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
      self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
      self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)")
      self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)")
      self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)")
      # NOTE(review): implied by the assertRegex above; possibly meant to
      # check "dynamic-slice-done" -- confirm intent.
      self.assertIn("dynamic-slice-start", compiled_text)

    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
      self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

  def test_remat_checkpoint_dots_with_no_batch_dims(self):
    # Offload (device -> pinned_host) the results of dots without batch dims.
    policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
        "device", "pinned_host")

    @functools.partial(new_checkpoint, policy=policy)
    def f(x):
      x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
      x = jnp.sin(x)
      x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
      x = jnp.sin(x)
      x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
      x = jnp.sin(x)
      x = jnp.sum(x)
      return x

    inp = jnp.ones((2, 2))
    f = jax.jit(jax.grad(f))
    f(inp)  # doesn't crash

    compiled_f = f.lower(inp).compile()

    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      self.assertIn('S(5)', compiled_text)
      self.assertRegex(compiled_text, r"copy-start.*S\(5\)")
      self.assertRegex(compiled_text, r"copy-done.*S\(5\)")

    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
      self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

  def test_primitive_with_multiple_outputs(self):
    # Test for https://github.com/jax-ml/jax/issues/25841
    shape = (128,)
    inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

    # Offload residuals of multi-output primitives; recompute everything else.
    def policy(prim, *args, **kwargs):
      del args, kwargs
      if prim.multiple_results:
        return Offloadable("device", "pinned_host")
      return Recompute

    @functools.partial(remat, policy=policy)
    def test_fn(x):
      # Need any primitive with multiple outputs and a non-trivial grad.
      x1, _ = jax.lax.approx_max_k(x, k=2)
      return jnp.sum(x1)

    fn = jax.grad(test_fn)
    jax.jit(fn)(inp)  # doesn't crash
if __name__ == "__main__":
  # Run the suite through absltest, using JAX's test loader.
  absltest.main(testLoader=jtu.JaxTestLoader())