# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

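# Tests for JAX memory kinds (e.g. "device", "pinned_host"): memory-kind
# handling on shardings, jax.device_put transfers between memory spaces,
# host compute offload via compute_on, and activation offloading with remat.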
import copy
import functools
import math
import re
from absl.testing import absltest
from absl.testing import parameterized
from absl import flags
import unittest

import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.layout import DeviceLocalLayout as DLL, Layout
from jax._src import config
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.numpy as jnp
from jax.ad_checkpoint import Offloadable, remat, Recompute
from jax._src.sharding import common_devices_indices_map
from jax._src.sharding_impls import (NamedSharding, PositionalSharding,
                                     SingleDeviceSharding, GSPMDSharding,
                                     TransferToMemoryKind, PartitionSpec as P)
from jax.experimental.compute_on import compute_on
from jax.experimental.shard_map import shard_map
import numpy as np

config.parse_flags_with_absl()
FLAGS = flags.FLAGS


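# Helpers shared by the test classes below.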
def get_memory_kinds_from_executable(f, args):
  compiled = f.lower(*args).compile()
  return compiled.runtime_executable().get_output_memory_kinds()[0]


def _create_inputs(shape, pspec, mem_kind=None):
  mesh = jtu.create_mesh((2, 2), ("x", "y"))
  np_inp = np.arange(math.prod(shape)).reshape(shape)
  s = NamedSharding(mesh, pspec, memory_kind=mem_kind)
  inp = jax.device_put(np_inp, s)
  return mesh, s, np_inp, inp


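# Checks that the different sharding types (NamedSharding, PositionalSharding,
# SingleDeviceSharding, GSPMDSharding) canonicalize, validate, and compare
# memory kinds consistently.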
class ShardingMemoriesTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if jtu.test_device_matches(["cpu"]):
      self._default_memory_kind = "unpinned_host"
    else:
      self._default_memory_kind = "device"

  @parameterized.named_parameters(
      ("named_sharding", "named_sharding"),
      ("positional_sharding", "positional_sharding"),
      ("single_device_sharding", "single_device_sharding"),
      ("gspmd_sharding", "gspmd_sharding"),
  )
  def test_canonicalize_memory_kind(self, name):
    if name == "named_sharding":
      mesh = jtu.create_mesh((1,), "x")
      ns = NamedSharding(mesh, P("x"))
      self.assertEqual(ns.memory_kind, self._default_memory_kind)
    elif name == "positional_sharding":
      ps = PositionalSharding(jax.devices())
      self.assertEqual(ps.memory_kind, self._default_memory_kind)
    elif name == "single_device_sharding":
      ss = SingleDeviceSharding(jax.devices()[0])
      self.assertEqual(ss.memory_kind, self._default_memory_kind)
    else:
      assert name == "gspmd_sharding"
      gs = GSPMDSharding.get_replicated(jax.devices())
      self.assertEqual(gs.memory_kind, self._default_memory_kind)

  @parameterized.named_parameters(
      ("named_sharding", "named_sharding"),
      ("positional_sharding", "positional_sharding"),
      ("single_device_sharding", "single_device_sharding"),
      ("gspmd_sharding", "gspmd_sharding"),
  )
  def test_wrong_memory_kind(self, name):
    if name == "named_sharding":
      with self.assertRaisesRegex(
          ValueError, "Could not find memory addressable by device.*"
      ):
        mesh = jtu.create_mesh((1,), ("x",))
        NamedSharding(mesh, P("x"), memory_kind="hbm")
    elif name == "positional_sharding":
      with self.assertRaisesRegex(
          ValueError, "Could not find memory addressable by device.*"
      ):
        PositionalSharding(jax.devices(), memory_kind="gpu_hbm")
    elif name == "single_device_sharding":
      with self.assertRaisesRegex(
          ValueError,
          "Could not find memory addressable by device.*Device.*"
          " can address the following memory kinds.*",
      ):
        SingleDeviceSharding(jax.devices()[0], memory_kind="host")
    else:
      assert name == "gspmd_sharding"
      with self.assertRaisesRegex(
          ValueError, "Could not find memory addressable by device.*"
      ):
        GSPMDSharding.get_replicated(jax.devices(), memory_kind="my_host")

  @parameterized.named_parameters(
      ("named_sharding", "named_sharding"),
      ("positional_sharding", "positional_sharding"),
      ("single_device_sharding", "single_device_sharding"),
      ("gspmd_sharding", "gspmd_sharding"),
  )
  def test_correct_tpu_memory_kind(self, name):
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("TPU memory kind test.")

    if name == "named_sharding":
      mesh = jtu.create_mesh((1,), ("x",))
      NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
    elif name == "positional_sharding":
      PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
    elif name == "single_device_sharding":
      SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host")
    else:
      assert name == "gspmd_sharding"
      GSPMDSharding.get_replicated(jax.devices(), memory_kind="unpinned_host")

  @parameterized.named_parameters(
      ("named_sharding", "named_sharding"),
      ("positional_sharding", "positional_sharding"),
      ("single_device_sharding", "single_device_sharding"),
      ("gspmd_sharding", "gspmd_sharding"),
  )
  def test_sharding_eq(self, name):
    if name == "named_sharding":
      mesh = jtu.create_mesh((1,), ("x",))
      s1 = NamedSharding(mesh, P("x"))
      s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
      self.assertEqual(s1, s2)
    elif name == "positional_sharding":
      s1 = PositionalSharding(jax.devices())
      s2 = PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
      self.assertEqual(s1, s2)
    elif name == "single_device_sharding":
      s1 = SingleDeviceSharding(jax.devices()[0])
      s2 = SingleDeviceSharding(jax.devices()[0], memory_kind=self._default_memory_kind)
      self.assertEqual(s1, s2)
    elif name == "gspmd_sharding":
      s1 = GSPMDSharding.get_replicated(jax.devices())
      s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind=self._default_memory_kind)
      self.assertEqual(s1, s2)

  def test_sharding_equivalent(self):
    mesh = jtu.create_mesh((1,), ("x",))
    ndim = 2
    ns1 = NamedSharding(mesh, P("x"))
    gs1 = GSPMDSharding(
        tuple(mesh.devices.flat),
        ns1._to_xla_hlo_sharding(ndim),
        memory_kind=self._default_memory_kind,
    )
    self.assertTrue(ns1.is_equivalent_to(gs1, ndim))

    ns2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
    gs2 = GSPMDSharding(
        tuple(mesh.devices.flat), ns2._to_xla_hlo_sharding(ndim)
    )
    self.assertTrue(ns2.is_equivalent_to(gs2, ndim))

  def test_default_memory_kind(self):
    dev = jax.devices()[0]
    self.assertEqual(dev.default_memory().kind, self._default_memory_kind)


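# Exercises jax.device_put between host and device memory kinds, including
# parameter/output streaming through jit and aliasing behavior of copies.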
class DevicePutTest(jtu.JaxTestCase):

  def setUp(self):
    if not jtu.test_device_matches(["tpu", "gpu"]):
      self.skipTest("Memories do not work on CPU backend yet.")
    super().setUp()

  def _check_device_put_addressable_shards(
      self, out, inp, expected_sharding, expected_mem_kind, index=True):
    self.assertArraysEqual(out, inp)
    self.assertEqual(out.sharding, expected_sharding)
    self.assertEqual(out.sharding.memory_kind, expected_mem_kind)
    for s in out.addressable_shards:
      if index:
        self.assertArraysEqual(s.data, inp[s.index])
      else:
        self.assertArraysEqual(s.data, inp)
      self.assertEqual(s.data.sharding.memory_kind, expected_mem_kind)

  def test_error_transfer_to_memory_kind_outside_jit(self):
    with self.assertRaisesRegex(
        ValueError,
        "TransferToMemoryKind argument to jax.device_put can only be used"
        " inside jax.jit"):
      jax.device_put(np.arange(16), TransferToMemoryKind("device"))

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_host_to_hbm(self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
    np_inp = np.arange(16).reshape(8, 2)

    out_on_host = jax.device_put(np_inp, s_host)
    self.assertEqual(out_on_host.sharding, s_host)

    s_hbm = s_host.with_memory_kind("device")
    out_on_hbm = jax.device_put(out_on_host, s_hbm)
    self._check_device_put_addressable_shards(
        out_on_hbm, np_inp, s_hbm, "device")

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_hbm_to_host(self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
    inp = jnp.arange(16).reshape(8, 2)

    out_on_host = jax.device_put(inp, s_host)
    self._check_device_put_addressable_shards(
        out_on_host, inp, s_host, host_memory_kind)

    sharded_inp = jax.device_put(inp, s_host.with_memory_kind("device"))
    sharded_out_on_host = jax.device_put(sharded_inp, s_host)
    self._check_device_put_addressable_shards(
        sharded_out_on_host, sharded_inp, s_host, host_memory_kind)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_different_device_and_memory_host_to_hbm(
      self, host_memory_kind: str
  ):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    if jax.device_count() < 3:
      raise unittest.SkipTest("Test requires >=3 devices")

    out_host0 = jax.device_put(
        jnp.arange(8),
        SingleDeviceSharding(jax.devices()[0], memory_kind=host_memory_kind))

    dev2 = jax.devices()[2]
    out_hbm1 = jax.device_put(
        out_host0, SingleDeviceSharding(dev2, memory_kind="device"))
    self.assertEqual(out_hbm1.sharding.memory_kind, "device")
    self.assertEqual(out_hbm1.sharding._device, dev2)
    self.assertEqual(out_hbm1.addressable_shards[0].data.sharding._device, dev2)
    self.assertEqual(
        out_hbm1.addressable_shards[0].data.sharding.memory_kind, "device")

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_different_device_and_memory_hbm_to_host(
      self, host_memory_kind: str
  ):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    if jax.device_count() < 3:
      raise unittest.SkipTest("Test requires >=3 devices")

    out_hbm0 = jnp.arange(8)

    dev2 = jax.devices()[2]
    out_host1 = jax.device_put(
        out_hbm0, SingleDeviceSharding(dev2, memory_kind=host_memory_kind))
    self.assertEqual(out_host1.sharding.memory_kind, host_memory_kind)
    self.assertEqual(out_host1.sharding._device, dev2)
    self.assertEqual(out_host1.addressable_shards[0].data.sharding._device,
                     dev2)
    self.assertEqual(
        out_host1.addressable_shards[0].data.sharding.memory_kind,
        host_memory_kind)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_on_different_device_with_the_same_memory_kind(
      self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    if len(jax.devices()) < 2:
      raise unittest.SkipTest("Test requires >=2 devices.")

    np_inp = np.arange(16).reshape(8, 2)

    s_hbm_dev_0 = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
    s_hbm_dev_1 = SingleDeviceSharding(jax.devices()[1], memory_kind="device")
    inp_hbm_dev0 = jax.device_put(np_inp, s_hbm_dev_0)
    out_hbm_dev_1 = jax.device_put(inp_hbm_dev0, s_hbm_dev_1)
    self._check_device_put_addressable_shards(
        out_hbm_dev_1, np_inp, s_hbm_dev_1, "device")

    inp_host_dev0 = jax.device_put(
        np_inp, s_hbm_dev_0.with_memory_kind(host_memory_kind))
    s_host_dev_1 = s_hbm_dev_1.with_memory_kind(host_memory_kind)
    out_host_dev_1 = jax.device_put(inp_host_dev0, s_host_dev_1)
    self._check_device_put_addressable_shards(
        out_host_dev_1, np_inp, s_host_dev_1, host_memory_kind)

  # TODO(yashkatariya): Enable this once we can compute on host.
  # def test_device_put_resharding(self):
  #   mesh = jtu.create_mesh((2, 2), ("x", "y"))
  #   s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")
  #   s_hbm = s_host.with_memory_kind("device")
  #   np_inp = np.arange(16).reshape(8, 2)

  #   # Reshard single device array on HBM to multi device on host
  #   sds_inp_hbm = jax.device_put(
  #       jnp.arange(16).reshape(8, 2),
  #       SingleDeviceSharding(jax.devices()[0], memory_kind="device"))
  #   # device_put on host
  #   out_sharded_host = jax.device_put(sds_inp_hbm, s_host)
  #   self._check_device_put_addressable_shards(
  #       out_sharded_host, np_inp, s_host, "unpinned_host")

  #   # Reshard single device array on host to multi device on hbm
  #   sds_inp_host = jax.device_put(
  #       jnp.arange(16).reshape(8, 2),
  #       SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host"))
  #   # device_put on hbm
  #   out_sharded_hbm = jax.device_put(sds_inp_host, s_hbm)
  #   self._check_device_put_addressable_shards(
  #       out_sharded_hbm, np_inp, s_hbm, "device")

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_numpy_array(self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    np_inp = np.arange(16).reshape(8, 2)
    s_hbm = NamedSharding(mesh, P(("x", "y")), memory_kind="device")
    s_host = s_hbm.with_memory_kind(host_memory_kind)

    out_hbm = jax.device_put(np_inp, s_hbm)
    self._check_device_put_addressable_shards(out_hbm, np_inp, s_hbm, "device")

    out_host = jax.device_put(np_inp, s_host)
    self._check_device_put_addressable_shards(
        out_host, np_inp, s_host, host_memory_kind)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_numpy_scalar(self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    np_inp = np.float32(8)
    s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
    s_host = s_hbm.with_memory_kind(host_memory_kind)

    out_hbm = jax.device_put(np_inp, s_hbm)
    self._check_device_put_addressable_shards(out_hbm, np_inp, s_hbm, "device")

    out_host = jax.device_put(np_inp, s_host)
    self._check_device_put_addressable_shards(
        out_host, np_inp, s_host, host_memory_kind)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_python_scalar(self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    py_scalar = float(8)
    s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
    s_host = s_hbm.with_memory_kind(host_memory_kind)

    out_hbm = jax.device_put(py_scalar, s_hbm)
    self._check_device_put_addressable_shards(
        out_hbm, py_scalar, s_hbm, "device", index=False)

    out_host = jax.device_put(py_scalar, s_host)
    self._check_device_put_addressable_shards(
        out_host, py_scalar, s_host, host_memory_kind, index=False)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_python_int(self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    py_inp = 8
    s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
    s_host = s_hbm.with_memory_kind(host_memory_kind)

    out_hbm = jax.device_put(py_inp, s_hbm)
    self._check_device_put_addressable_shards(
        out_hbm, py_inp, s_hbm, "device", index=False)

    out_host = jax.device_put(py_inp, s_host)
    self._check_device_put_addressable_shards(
        out_host, py_inp, s_host, host_memory_kind, index=False)

  def test_device_put_inside_jit(self):
    _, s_host, np_inp, inp_host = _create_inputs(
        (8, 2), P("x", "y"), mem_kind="pinned_host")
    s_dev = s_host.with_memory_kind("device")

    @jax.jit
    def f(a, b):
      x, y = jax.device_put((a, b), s_dev)
      return x * y

    out = f(inp_host, inp_host)
    self._check_device_put_addressable_shards(
        out, np_inp * np_inp, s_dev, "device")

  def test_parameter_streaming(self):
    _, s_host, np_inp, inp_host = _create_inputs(
        (8, 2), P("x", "y"), mem_kind="pinned_host")
    s_dev = s_host.with_memory_kind('device')
    inp_dev = jax.device_put(np_inp, s_dev)

    @functools.partial(jax.jit, out_shardings=s_host)
    def f(a, b):
      x = b * 2
      y = jax.device_put(a, s_dev)
      z = x * y
      return z * 4, z

    compiled = f.lower(inp_host, inp_dev).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out1, out2 = f(inp_host, inp_dev)
    self._check_device_put_addressable_shards(
        out1, np_inp * np_inp * 8, s_host, 'pinned_host')
    self._check_device_put_addressable_shards(
        out2, np_inp * np_inp * 2, s_host, 'pinned_host')

  def test_zero_size_parameter(self):
    if jtu.test_device_matches(["gpu"]):
      self.skipTest("This test does not work on GPU backend.")
    _, s_host, np_inp, inp_host = _create_inputs(
        (0,), P(), mem_kind="pinned_host")
    s_dev = s_host.with_memory_kind('device')

    @functools.partial(jax.jit, out_shardings=s_host)
    def f(a):
      b = jax.device_put(a, s_dev)
      return b

    compiled = f.lower(inp_host).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out = f(inp_host)
    self._check_device_put_addressable_shards(
        out, np_inp, s_host, 'pinned_host')

  def test_parameter_streaming_with_scalar_and_constant(self):
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    scalar_inp = 1
    s_host = NamedSharding(mesh, P(), memory_kind="pinned_host")

    @functools.partial(jax.jit, out_shardings=s_host)
    def f(scalar_input):
      y = jax.device_put(scalar_input, s_host)
      z = 2
      w = jax.device_put(z, s_host)
      return y, w

    compiled = f.lower(scalar_inp).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out1, out2 = f(scalar_inp)
    self._check_device_put_addressable_shards(
        out1, scalar_inp, s_host, "pinned_host", index=False
    )
    self._check_device_put_addressable_shards(
        out2, 2, s_host, "pinned_host", index=False
    )

  def test_parameter_and_output_streaming_with_array(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    np_inp = np.arange(16).reshape(8, 2)
    s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host")
    inp_host = jax.device_put(np_inp, s_host)

    @functools.partial(jax.jit, out_shardings=(s_host, s_host))
    def f(x):
      return (x, x)

    compiled = f.lower(inp_host).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    if compiled_text is not None:
      self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out1, out2 = f(inp_host)
    self._check_device_put_addressable_shards(
        out1, np_inp, s_host, "pinned_host"
    )
    self._check_device_put_addressable_shards(
        out2, np_inp, s_host, "pinned_host"
    )

  def test_parameter_and_output_streaming_with_scalar(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")

    mesh = jax.sharding.Mesh(jax.devices(), "axis")
    s_host = jax.sharding.NamedSharding(
        mesh, jax.sharding.PartitionSpec(), memory_kind="pinned_host"
    )
    scalar_inp = 1

    @functools.partial(jax.jit, out_shardings=(s_host, s_host))
    def f(x):
      return (x, x)

    compiled = f.lower(scalar_inp).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    if compiled_text is not None:
      self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out1, out2 = f(scalar_inp)
    self._check_device_put_addressable_shards(
        out1, scalar_inp, s_host, "pinned_host", index=False
    )
    self._check_device_put_addressable_shards(
        out2, scalar_inp, s_host, "pinned_host", index=False
    )

  def test_identity_jit_host_to_device_and_vice_versa(self):
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    np_inp = np.arange(16).reshape(8, 2)
    s_host = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
    s_dev = s_host.with_memory_kind('device')
    arr_host = jax.device_put(np_inp, s_host)
    arr_dev = jax.device_put(np_inp, s_dev)

    # pinned_host -> device
    f = jax.jit(lambda x: x, out_shardings=s_dev)
    out_dev = f(arr_host)
    self.assertArraysEqual(out_dev, np_inp)
    self.assertEqual(out_dev.sharding, s_dev)

    # device -> pinned_host
    g = jax.jit(lambda x: x, out_shardings=s_host)
    out_host = g(arr_dev)
    self.assertArraysEqual(out_host, np_inp)
    self.assertEqual(out_host.sharding, s_host)

  def test_parameter_streaming_inside_scan(self):
    mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z"))
    np_inp = np.arange(4096.0).reshape(16, 16, 16)
    s_host = NamedSharding(mesh, P("x", "y", "z"), memory_kind="pinned_host")
    arr_host = jax.device_put(np_inp, s_host)

    @jax.jit
    def f(xs):
      def body(carry, x):
        x_tpu = jax.device_put(x, TransferToMemoryKind("device"))
        return carry, x_tpu + carry

      return jax.lax.scan(body, 1.0, xs)

    _, out_hbm = f(arr_host)
    self.assertArraysEqual(out_hbm, np_inp + 1.0)
    # Only expect the last dimension to have a named sharding.
    out_s = NamedSharding(mesh, P(None, None, "z"), memory_kind="device")
    self.assertEqual(out_hbm.sharding, out_s)

  def test_output_streaming(self):
    mesh = jtu.create_mesh((1, 1), ("x", "y"))
    np_inp = np.arange(16.0).reshape(8, 2)
    s_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="device")
    s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host")
    arr_hbm = jax.device_put(np_inp, s_hbm)

    @functools.partial(jax.jit, out_shardings=s_host)
    def f(xs):
      out_tpu = xs + 1.0
      return out_tpu

    out_host = f(arr_hbm)
    self.assertArraysEqual(out_host, np_inp + 1.0)
    self.assertEqual(out_host.sharding, s_host)

  def test_weight_offload_with_dp_on_output(self):
    _, s_dev, np_inp, inp_dev = _create_inputs(
        (8, 2), P("x", "y"), mem_kind="device")
    s_host = s_dev.with_memory_kind('pinned_host')

    @jax.jit
    def f(x):
      x = x * 2
      y = jax.device_put(x, s_host)
      return y

    out_host = f(inp_dev)
    self._check_device_put_addressable_shards(
        out_host, np_inp * 2, s_host, 'pinned_host')

  def test_output_streaming_inside_scan(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z"))
    np_inp = np.arange(4096).reshape(16, 16, 16)
    s_hbm = NamedSharding(mesh, P(None, "y", "z"), memory_kind="device")
    arr_hbm = jax.device_put(np_inp, s_hbm)

    @jax.jit
    def f(xs):
      def body(carry, x):
        out_tpu = x + carry
        return carry, jax.device_put(
            out_tpu, NamedSharding(mesh, P("y", "z"), memory_kind="pinned_host"))
      _, res = jax.lax.scan(body, 1, xs)
      return res

    out = f(arr_hbm)
    self.assertArraysEqual(out, np_inp + 1)
    self.assertEqual(out.sharding.memory_kind, 'pinned_host')

  def test_deepcopy(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    mesh = jax.sharding.Mesh(jax.devices(), "x")
    s_host = NamedSharding(mesh, P(), memory_kind="pinned_host")

    t = jax.device_put(jnp.zeros((8, 2)), s_host)
    t_copy = copy.deepcopy(t)
    self.assertArraysEqual(t, t_copy)
    self.assertEqual(t.shape, t_copy.shape)

  def test_close_over_host_constant_and_stream(self):

    _, s_host, np_inp, inp_host = _create_inputs(
        (8, 2), P("x", "y"), mem_kind="pinned_host")
    s_dev = s_host.with_memory_kind('device')

    @functools.partial(jax.jit, out_shardings=s_dev)
    def f():
      y = jax.device_put(inp_host, s_dev)
      z = y * 2
      return z

    out = f()
    self._check_device_put_addressable_shards(out, np_inp * 2, s_dev, 'device')

  @jtu.run_on_devices('tpu')
  def test_ragged_copy_on_host(self):
    mesh = jtu.create_mesh((2,), ('x'))
    sharding = jax.sharding.NamedSharding(mesh, P(('x')))
    cpu_sharding = sharding.with_memory_kind('pinned_host')

    num_pages = 512 * 1024
    page_size = 1024

    x = jnp.full((num_pages, page_size), 1, dtype=jnp.bfloat16, device=sharding)

    def write(x):
      return x.at[16 * 1024:].set(0)
    x = shard_map(write, mesh, P(('x'),), P(('x')))(x)

    chunk_size = 8
    def inner(state):
      idx, x, output = state
      chunk = jax.lax.dynamic_slice_in_dim(x, idx * chunk_size, chunk_size)
      chunk_host = jax.device_put(chunk, TransferToMemoryKind('pinned_host'))
      output = jax.lax.dynamic_update_slice_in_dim(
          output, chunk_host, idx * chunk_size, axis=0)
      return (idx + 1, x, output)

    def cond(state):
      idx, x, _ = state
      chunk = jax.lax.dynamic_slice_in_dim(x, idx * chunk_size, chunk_size)
      return (idx * chunk_size < x.shape[0]) & jnp.any(chunk > 0)

    def foo(x):
      output = jnp.zeros_like(x, device=cpu_sharding)
      _, _, cpu_x = jax.lax.while_loop(cond, inner, (0, x, output))
      return cpu_x

    fn = jax.jit(shard_map(foo, mesh, P(('x'),), P(('x')),
                           check_rep=False),
                 out_shardings=cpu_sharding)
    y = fn(x)
    jax.block_until_ready(y)
    compiled_text = fn.lower(x).compile().as_text()
    if compiled_text is not None:
      self.assertIn('custom_call_target="AllocateBuffer"', compiled_text)

  def test_disallow_alias_copies_arrays(self):
    mesh = jtu.create_mesh((2,), ("x",))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P("x"), memory_kind="pinned_host")
    inp_host = jax.device_put(np_inp, s)

    inp_host_copy = jax.device_put(inp_host, may_alias=False)

    for a in jax.tree.leaves(inp_host):
      a.delete()

    jax.block_until_ready(inp_host_copy)

  def test_disallow_alias_copies_arrays_with_donated_input(self):
    mesh = jtu.create_mesh((2,), ("x",))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P("x"), memory_kind="pinned_host")
    inp_host = jax.device_put(np_inp, s)

    inp_host_donate = jax.jit(lambda x: x, donate_argnums=0)(inp_host)

    inp_host_donate_copy = jax.device_put(inp_host_donate, may_alias=False)

    for a in jax.tree.leaves(inp_host_donate):
      a.delete()

    jax.block_until_ready(inp_host_donate_copy)


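# Exercises compute_on('device_host') host offload: host-only computations,
# donation, custom layouts, and interaction with scan, remat, grad and
# custom_vjp.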
class ComputeOffload(jtu.BufferDonationTestCase):

  def setUp(self):
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("Memories do not work on CPU and GPU backends yet.")
    super().setUp()

  def _check_mem_kind(self, executable_kind, out_sharding, expected_kind):
    out_kind = out_sharding.memory_kind
    self.assertEqual(executable_kind, out_kind)
    self.assertEqual(out_kind, expected_kind)
    self.assertEqual(executable_kind, expected_kind)

  def test_compute_no_inputs(self):
    mesh = jtu.create_mesh((4,), ('data'))

    tpu_sharding = NamedSharding(mesh, P('data'))
    cpu_sharding = NamedSharding(mesh, P('data'), memory_kind='pinned_host')

    @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding))
    def init():
      tpu_array = jax.random.normal(jax.random.key(42), (16,16))
      cpu_array = jax.random.normal(jax.random.key(42), (16,16))
      return tpu_array, cpu_array

    tpu_array, cpu_array = init()
    self.assertEqual(tpu_array.sharding, tpu_sharding)
    self.assertEqual(cpu_array.sharding, cpu_sharding)

  def test_compute_no_inputs_host_replicated(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3:
      self.skipTest("This test requires an xla_version >= 3.")
    if config.use_shardy_partitioner.value:
      self.skipTest("XLA failure due to b/370786664 and b/366411266. "
                    "Enable when fixed.")
    mesh = jtu.create_mesh((4,), ('data'))

    tpu_sharding = NamedSharding(mesh, P('data'))
    cpu_sharding = NamedSharding(mesh, P(), memory_kind='pinned_host')

    @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding))
    def init():
      tpu_array = jax.random.normal(jax.random.key(42), (16, 16))
      cpu_array = jax.random.normal(jax.random.key(42), (16, 16))
      return tpu_array, cpu_array

    tpu_array, cpu_array = init()
    self.assertEqual(tpu_array.sharding, tpu_sharding)
    self.assertEqual(cpu_array.sharding, cpu_sharding)

  def test_compute_on_basic(self):
    out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host')

    @compute_on('device_host')
    @jax.jit
    def g(x):
      return x * 2

    @jax.jit
    def f(x):
      y = g(x)
      return y * 3

    inp = jnp.arange(8)
    out = f(inp)
    self.assertArraysEqual(out, inp * 6)

    lowered_text = f.lower(jnp.arange(8)).as_text()
    self.assertIn('_xla_compute_type', lowered_text)

    @functools.partial(jax.jit, out_shardings=out_s)
    def h(x):
      y = g(x)
      return y * 3

    out2 = h(inp)
    self.assertArraysEqual(out2, inp * 6)
    self.assertEqual(out2.sharding.memory_kind, 'pinned_host')

  def test_compute_on_host_shared_sharding(self):
    mesh = jtu.create_mesh((2,), ("x"))
    device_sharding = NamedSharding(mesh, P("x"))
    host_sharding = device_sharding.with_memory_kind("pinned_host")

    @compute_on("device_host")
    @functools.partial(
        jax.jit,
        in_shardings=(host_sharding, device_sharding),
        out_shardings=(host_sharding, device_sharding),
        donate_argnums=(0, 1),
    )
    def host_func(x, y):
      return (x * y), ((x**2) * (y**2))

    @functools.partial(
        jax.jit,
        in_shardings=(host_sharding, device_sharding),
        out_shardings=(host_sharding, device_sharding),
        donate_argnums=(0),
    )
    def device_func(host_data, device_data):
      host_data, device_data = host_func(host_data, device_data)
      device_data = device_data * 2
      host_data, device_data = host_func(host_data, device_data)
      return (host_data, device_data)

    input_x = jnp.ones(8)
    input_host = jax.device_put(input_x, host_sharding)

    input_device = jnp.arange(8)
    input_device = jnp.where(input_device < 4, 0, 1)
    input_device = jax.device_put(input_device, device_sharding)

    output_host, output_device = device_func(input_host, input_device)
    self.assertEqual(output_host.sharding.memory_kind, 'pinned_host')
    self.assertEqual(output_device.sharding.memory_kind, 'device')
    self.assertArraysEqual(output_host, [0., 0., 0., 0., 2., 2., 2., 2.])
    self.assertArraysEqual(output_device, [0., 0., 0., 0., 4., 4., 4., 4.])

  def test_compute_on_basic_inline(self):
    @compute_on('device_host')
    @jax.jit
    def g(x):
      return x * 2

    @functools.partial(jax.jit, inline=True)
    def h(x):
      y = g(x)
      return y * 3

    @jax.jit
    def f(x):
      return h(x)

    inp = jnp.arange(8)
    out = f(inp)
    self.assertArraysEqual(out, inp * 6)

    lowered_text = f.lower(jnp.arange(8)).as_text('hlo')
    self.assertRegex(lowered_text,
                     'to_apply=g.*frontend_attributes={_xla_compute_type="host"}')

  def test_compute_on_reduction(self):
    out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host')

    @compute_on('device_host')
    @jax.jit
    def g(x):
      # Reduction generates multiple host computations (inside a single host
      # computation module): the main one and a reduction body.
      return jnp.sum(x)

    @jax.jit
    def f(x):
      y = g(x)
      z = jnp.sum(x)
      return y * z

    inp = jnp.arange(8)
    out = f(inp)
    self.assertArraysEqual(out, np.sum(inp) * np.sum(inp))

    lowered_text = f.lower(jnp.arange(8)).as_text()
    self.assertIn('_xla_compute_type', lowered_text)

    @functools.partial(jax.jit, out_shardings=out_s)
    def h(x):
      y = g(x)
      z = jnp.sum(x)
      return y * z

    out2 = h(inp)
    self.assertArraysEqual(out2, np.sum(inp) * np.sum(inp))
    self.assertEqual(out2.sharding.memory_kind, 'pinned_host')

  def test_compute_host_loop(self):
    # TODO(apaszke): Remove after 12 weeks have passed.
    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
      self.skipTest("Requires libtpu built after 2024-12-19")
    @compute_on('device_host')
    @jax.jit
    def fn():
      k = jax.random.key(0)
      return jax.nn.initializers.lecun_normal()(k, (2, 2), jnp.float32)
    fn()  # doesn't crash

    @compute_on('device_host')
    def fn():
      k = jax.random.key(0)
      return jax.nn.initializers.lecun_normal()(k, (2, 2), jnp.float32)
    fn()  # doesn't crash

  def test_nested_compute_error(self):
    @compute_on('device')
    @jax.jit
    def f0(x):
      return x * 2

    @compute_on('device_host')
    @jax.jit
    def f1(x):
      return f0(x)

    @jax.jit
    def f2(x):
      return f1(x)

    with self.assertRaisesRegex(
        NotImplementedError,
        "Nesting `compute_on` with different compute types is not supported"
        " yet."):
      f2(jnp.arange(8))

  def test_compute_on_grad(self):
    @compute_on('device_host')
    @jax.jit
    def g(x):
      return jnp.sin(x)

    def f(x):
      y = g(x)
      return jnp.sum(y)

    inp = jnp.arange(8.)
    jf = jax.jit(jax.grad(f))

    jtu.check_grads(jf, (inp,), order=2)

    lowered_text = jf.lower(inp).as_text('hlo')
    out = re.findall(r"call.*to_apply.*_xla_compute_type", lowered_text)
    self.assertLen(out, 2)

  def test_compute_on_remat(self):
    inp = jnp.arange(16.)

    def policy(prim, *avals, **params):
      return Recompute

    @compute_on('device_host')
    @jax.jit
    def g(x):
      x = jnp.sin(x)
      x = jnp.sin(x)
      x = jnp.sin(x)
      return x

    @functools.partial(remat, policy=policy)
    def f(x):
      x = g(x)
      return jnp.sum(x)

    # Execution test.
    jf = jax.jit(jax.grad(f))
    jf(inp)  # doesn't crash

    lowered_text = jf.lower(inp).as_text('hlo')
    out = re.findall(r"call.*to_apply.*_xla_compute_type", lowered_text)
    self.assertLen(out, 2)

  def test_nested_no_op_compute(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    arr = jax.device_put(np_inp, s)

    @compute_on('device_host')
    @jax.jit
    def f0(x):
      return x * 2

    @compute_on('device_host')
    @jax.jit
    def f1(x):
      x = x * 3
      return f0(x)

    @jax.jit
    def f2(x):
      return f1(x)

    out = f2(arr)
    self.assertArraysEqual(out, arr * 6)
    self.assertEqual(out.sharding, s)

  def test_sharded_compute_on_host(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    arr = jax.device_put(np_inp, s)

    @compute_on('device_host')
    @jax.jit
    def g(x, y):
      return x * y

    @jax.jit
    def f(x):
      x = x * 3
      return g(x, x)

    out = f(arr)
    expected_out = (np_inp * 3) * (np_inp * 3)
    self.assertEqual(out.sharding, s)
    self.assertArraysEqual(out, expected_out)

  def test_host_offload_in_custom_vjp(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    @jax.custom_vjp
    def f(x):
      return jnp.sin(x)

    @compute_on('device_host')
    @jax.jit
    def eq(x, y):
      return (x == y).astype(jnp.float32)

    def f_fwd(x):
      y = x * 2
      z = jax.device_put(y, TransferToMemoryKind('pinned_host'))
      return y, (x, z)

    def f_bwd(res, tx):
      x, z = res
      y = x * 2
      z2 = jax.device_put(y, TransferToMemoryKind('pinned_host'))
      return (eq(z, z2),)

    f.defvjp(f_fwd, f_bwd)
    g = jax.jit(jax.grad(lambda x: f(x).sum()))

    x = jnp.ones(3) * 4
    all_true = jnp.ones(3)
    self.assertArraysEqual(g(x), all_true)

  def test_host_offload_in_custom_vjp_sharded(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    s = NamedSharding(mesh, P('x'))

    @jax.custom_vjp
    def f(x):
      return jnp.sin(x)

    @compute_on('device_host')
    @jax.jit
    def eq(x, y):
      return (x == y).astype(jnp.float32)

    def f_fwd(x):
      y = x * 2
      z = jax.device_put(y, s.with_memory_kind('pinned_host'))
      return y, (x, z)

    def f_bwd(res, tx):
      x, z = res
      y = x * 2
      z2 = jax.device_put(y, s.with_memory_kind('pinned_host'))
      return (eq(z, z2),)

    f.defvjp(f_fwd, f_bwd)
    g = jax.jit(jax.grad(lambda x: f(x).sum()))

    arr = jax.device_put(jnp.ones(4) * 4, s)
    all_true = jnp.ones(4)
    self.assertArraysEqual(g(arr), all_true)

  def test_scan_offload(self):
    np_inp = jnp.arange(4096).reshape(16, 16, 16)

    @jax.jit
    def f(xs):
      def body(carry, x):
        with compute_on('device_host'):
          out_tpu = x + carry
        return carry, out_tpu
      _, res = jax.lax.scan(body, 1, xs)
      return res

    out = f(np_inp)
    self.assertArraysEqual(out, np_inp + 1)

    @compute_on('device_host')
    @jax.jit
    def body2(carry, x):
      out_tpu = x + carry
      return carry, out_tpu

    @jax.jit
    def f2(xs):
      _, res = jax.lax.scan(body2, 1, xs)
      return res

    out2 = f2(np_inp)
    self.assertArraysEqual(out2, np_inp + 1)

  @parameterized.parameters(True, False)
  def test_copy_offload(self, jit_compute_fn: bool):
    # test an explicit copy within the host computation.

    def g(x):
      return jnp.copy(x) * 2

    @jax.jit
    def f(x):
      if jit_compute_fn:
        y = compute_on("device_host")(jax.jit(g))(x)
      else:
        y = compute_on("device_host")(g)(x)
      return y * 3

    inp = jnp.arange(8)
    out = f(inp)
    self.assertArraysEqual(out, inp * 6)

    lowered_text = f.lower(jnp.arange(8)).as_text()
    self.assertIn('_xla_compute_type', lowered_text)

  def test_pure_host_data_and_compute(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
    np_inp = np.arange(16).reshape(8, 2)
    arr_host = jax.device_put(np_inp, s)

    @compute_on('device_host')
    @jax.jit
    def g(x):
      return x * x

    @functools.partial(jax.jit, out_shardings=s)
    def f(x):
      return g(x)

    out = f(arr_host)
    self.assertEqual(out.sharding, s)
    self.assertEqual(out.sharding.memory_kind, 'pinned_host')
    self.assertArraysEqual(out, np_inp * np_inp)

  def test_eager_compute(self):
    inp = jnp.arange(8.)
    with compute_on('device_host'):
      out = inp * 2
      out = jnp.sin(out)
    self.assertArraysAllClose(out, jnp.sin(inp * 2))

  def test_compute_per_annotation(self):
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    s = NamedSharding(mesh, P("x", "y"))
    np_inp = np.arange(16.).reshape(8, 2)
    arr = jax.device_put(np_inp, s)

    @jax.jit
    @compute_on('device_host')
    def f(x):
      return jnp.sin(x * 2)

    # # sharded input
    out = f(arr)
    self.assertArraysAllClose(out, np.sin(np_inp * 2))

    out2 = f(np_inp)
    self.assertArraysAllClose(out2, np.sin(np_inp * 2))

  def test_jit_host_multi_outputs(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    _, s, np_inp, inp = _create_inputs((8, 2), P("x"))

    @jax.jit
    def f(x, y):
      x, y = jnp.sin(x), jnp.cos(y)
      x = jax.device_put(x, s.with_memory_kind("pinned_host"))
      y = jax.device_put(y, s.with_memory_kind("device"))
      return x, y

    out1, out2 = f(inp, inp)

    self.assertArraysAllClose(out1, np.sin(np_inp))
    self.assertArraysAllClose(out2, np.cos(np_inp))
    self.assertEqual(out1.sharding, s.with_memory_kind("pinned_host"))
    self.assertEqual(out2.sharding, s.with_memory_kind("device"))

  def test_jit_out_shardings_single_output(self):
    mesh, _, _, inp = _create_inputs((8, 2), P("x", "y"))
    out_s = NamedSharding(mesh, P(), memory_kind="pinned_host")

    @functools.partial(jax.jit, out_shardings=out_s)
    def g(x):
      return jnp.sum(x * 2)

    out = g(inp)
    self.assertEqual(out.sharding, out_s)
    executable_mk = get_memory_kinds_from_executable(g, [inp])
    self._check_mem_kind(executable_mk[0], out.sharding, "pinned_host")

    @jax.jit
    def h(x):
      x = jnp.sum(x * 2)
      out = jax.device_put(x, out_s)
      return out

    out = h(inp)
    self.assertEqual(out.sharding, out_s)
    executable_mk = get_memory_kinds_from_executable(h, [inp])
    self._check_mem_kind(executable_mk[0], out.sharding, "pinned_host")

  def test_jit_in_shardings(self):
    _, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"))

    @functools.partial(jax.jit, in_shardings=s.with_memory_kind("pinned_host"))
    def f(x):
      return x * 2

    with self.assertRaisesRegex(
        ValueError,
        "Memory kinds passed to jax.jit does not match memory kind on the"
        " respective arg. Got pjit memory kind: pinned_host, arg memory kind:"
        " device for arg shape.*"):
      f(jnp.arange(16).reshape(8, 2))  # uncommitted inp also raises error

    with self.assertRaisesRegex(
        ValueError,
        "Memory kinds passed to jax.jit does not match memory kind on the"
        " respective arg. Got pjit memory kind: pinned_host, arg memory kind:"
        " device for arg shape.*"):
      f(inp)  # committed inp raises error.

    @functools.partial(jax.jit, in_shardings=s.with_memory_kind("device"))
    def g(x):
      return x * 2

    out = g(inp)
    executable_kind = get_memory_kinds_from_executable(g, [inp])
    self.assertArraysEqual(out, np_inp * 2)
    self._check_mem_kind(executable_kind[0], out.sharding, "device")

  def test_jit_in_out_shardings(self):
    mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"), mem_kind="device")
    out_s = NamedSharding(mesh, P(), memory_kind="device")

    @functools.partial(jax.jit, in_shardings=s, out_shardings=out_s)
    def f(x):
      return jnp.sum(x)

    out = f(inp)
    executable_kind = get_memory_kinds_from_executable(f, [inp])
    self.assertArraysEqual(out, np.sum(np_inp))
    self._check_mem_kind(executable_kind[0], out.sharding, "device")

    @functools.partial(
        jax.jit,
        in_shardings=s,
        out_shardings=out_s.with_memory_kind("pinned_host"),
    )
    def g(x):
      return jnp.sum(x)

    out = g(inp)
    executable_kind = get_memory_kinds_from_executable(g, [inp])
    self.assertArraysEqual(out, np.sum(np_inp))
    self._check_mem_kind(executable_kind[0], out.sharding, "pinned_host")

  def test_device_put_different_devices(self):
    _, _, _, inp = _create_inputs((8, 2), P("x", "y"))

    @jax.jit
    def f(x):
      return jax.device_put(
          x, SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host"))

    with self.assertRaisesRegex(
        ValueError, "Received incompatible devices for jitted computation"):
      f(inp)

  def test_jit_cpp_cache_hit(self):
    mesh, _, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
    inp2 = jax.device_put(
        np_inp, NamedSharding(mesh, P("x", "y"), memory_kind="device"))

    f = jax.jit(lambda x: x @ x.T)

    with jtu.count_pjit_cpp_cache_miss() as count:
      out = f(inp)
      out2 = f(inp2)
    self.assertEqual(count(), 1)

    self.assertArraysEqual(out, np_inp @ np_inp.T)
    self.assertArraysEqual(out2, np_inp @ np_inp.T)

  def test_jit_compilation_cache_hit(self):
    if config.use_shardy_partitioner.value:
      self.skipTest("Shardy doesn't support GSPMDSharding")
    mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
    inp2 = jax.device_put(
        np_inp, GSPMDSharding(tuple(mesh.devices.flat),
                              s._to_xla_hlo_sharding(inp.ndim),
                              memory_kind="device")
    )

    f = jax.jit(lambda x: x @ x.T)

    with (jtu.count_pjit_cpp_cache_miss() as cpp_count,
          jtu.count_jit_and_pmap_lowerings() as lowering_count):
      f(inp)
      f(inp2)
    self.assertEqual(cpp_count(), 2)
    self.assertEqual(lowering_count(), 2)

  def test_jit_cpp_cache_output_hit(self):
    _, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="device")

    @jax.jit
    def mul_two(x):
      return x * 2

    with jtu.count_pjit_cpp_cache_miss() as count:
      out = mul_two(inp)
      mul_two(out)
    self.assertEqual(count(), 1)

  def test_jit_cache_hit_with_default_and_specified_mem_kind(self):
    _, s, np_inp, _ = _create_inputs((8, 2), P("x", "y"))
    _, s2, np_inp2, _ = _create_inputs((8, 2), P("x", "y"), mem_kind="device")

    def mul(x):
      return x @ x.T

    f = jax.jit(mul, in_shardings=s)
    g = jax.jit(mul, in_shardings=s2)

    with jtu.count_jit_and_pmap_lowerings() as count:
      out = f(np_inp)
      out2 = g(np_inp2)
    self.assertEqual(count(), 1)

    self.assertArraysEqual(out, np_inp @ np_inp.T)
    self.assertArraysEqual(out2, np_inp2 @ np_inp2.T)

  def test_sharding_devices_indices_map_cache_hit(self):
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    shape = (8, 2)
    s1 = NamedSharding(mesh, P("x", "y"))
    s2 = NamedSharding(mesh, P("x", "y"), memory_kind="device")

    s1.devices_indices_map(shape)
    cache_info1 = common_devices_indices_map.cache_info()
    s2.devices_indices_map(shape)
    cache_info2 = common_devices_indices_map.cache_info()
    self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
    self.assertEqual(cache_info2.misses, cache_info1.misses)

  def test_no_donation_across_memory_kinds(self):
    if xb.using_pjrt_c_api():
      raise unittest.SkipTest("GetOutputShardings not supported in PJRT C API")
    mesh = jtu.create_mesh((2, 1), ("x", "y"))
    np_inp = np.arange(16).reshape(8, 2)
    s_hbm = NamedSharding(mesh, P("x"))
    s_host = s_hbm.with_memory_kind("pinned_host")
    inp = jax.device_put(np_inp, s_hbm)

    @functools.partial(jax.jit, out_shardings=s_host, donate_argnums=0)
    def f(x):
      return x * 2

    with self.assertWarnsRegex(
        UserWarning, "Some donated buffers were not usable"):
      f(inp)

    lowered_text = f.lower(inp).as_text("hlo")
    self.assertNotIn("input_output_alias", lowered_text)
    self.assertNotDeleted(inp)

  def test_single_mem_kind_donation_default_mem_kind(self):
    mesh = jtu.create_mesh((2,), "x")
    s = NamedSharding(mesh, P())

    @functools.partial(jax.jit, out_shardings=s, donate_argnums=0)
    def f(inp1):
      return inp1 * 2

    x = jax.device_put(np.arange(16).reshape(8, 2), s)

    f(x)

    lowered_text = f.lower(x).as_text("hlo")
    self.assertIn("input_output_alias", lowered_text)
    self.assertDeleted(x)

  def test_compute_offload_inside_shmap(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    arr = jax.device_put(np_inp, s)

    @compute_on('device_host')
    @jax.jit
    def g(x):
      return x * 2

    def f(x):
      x = x * 3
      y = g(x)
      return y * 4

    out = jax.jit(shard_map(f, mesh=mesh, in_specs=P('x', 'y'),
                            out_specs=P('x', 'y')))(arr)
    self.assertArraysEqual(out, np_inp * 24)

  def test_qr_decomposition_offload(self):
    if jtu.is_cloud_tpu():
      self.skipTest("Test fails on cloud TPU")

    shape = (3, 3)
    dtype = np.float32
    operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)

    @compute_on("device_host")
    @jax.jit
    def g(x):
      return lax.linalg.qr(x, full_matrices=True)

    @jax.jit
    def f(x):
      x, _ = lax.linalg.qr(x, full_matrices=True)
      x, _ = g(x)
      return x

    out = f(operand)  # doesn't crash
    lowered_text = f.lower(operand).as_text()
    self.assertIn('@lapack_sgeqrf', lowered_text)
    self.assertIn('@Qr', lowered_text)

    @jax.jit
    def h(x):
      x, _ = lax.linalg.qr(x, full_matrices=True)
      x, _ = lax.linalg.qr(x, full_matrices=True)
      return x

    expected_out = h(operand)

    self.assertArraysAllClose(out, expected_out, rtol=1e-3)

  def test_mem_kind_donation_pinned_host(self):
    mesh = jtu.create_mesh((2,), "x")
    s = NamedSharding(mesh, P(), memory_kind='pinned_host')
    s_dev = s.with_memory_kind('device')

    @compute_on('device_host')
    @functools.partial(jax.jit, out_shardings=(s, s_dev), donate_argnums=(0, 1))
    def f(inp1, inp2):
      return inp1 * 2, inp2 * 2

    np_inp = np.arange(16).reshape(8, 2)
    x = jax.device_put(np_inp, s)
    x_dev = jax.device_put(np_inp, s_dev)

    f(x, x_dev)

    lowered_text = f.lower(x, x_dev).as_text("hlo")
    self.assertIn("input_output_alias", lowered_text)
    self.assertDeleted(x)
    self.assertDeleted(x_dev)

  @parameterized.parameters("pinned_host", "device")
  def test_identity_mem_kind_donation(self, mem_kind):
    mesh = jtu.create_mesh((2,), "x")
    s = NamedSharding(mesh, P(), memory_kind=mem_kind)

    @functools.partial(jax.jit, out_shardings=s, donate_argnums=0)
    def f(inp):
      return inp

    np_inp = np.arange(16).reshape(8, 2)
    x = jax.device_put(np_inp, s)

    f(x)

    lowered_text = f.lower(x).as_text("hlo")
    self.assertIn("input_output_alias", lowered_text)
    self.assertDeleted(x)

  def test_compute_offload_with_donation(self):
    sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    p_sharding = jax.sharding.SingleDeviceSharding(
        jax.devices()[0], memory_kind="pinned_host"
    )

    @compute_on("device_host")
    @jax.jit
    def host_fn(x_in, y_in):
      return x_in * x_in, y_in + y_in

    def test_fn(x_in, y_in):
      x_out, y_out = host_fn(x_in, y_in)
      return x_out, y_out

    x = jnp.arange(0, 1024, dtype=jnp.float32)
    y = jnp.arange(0, 1024, dtype=jnp.float32)
    y = jax.device_put(y, p_sharding)

    x1 = jnp.arange(0, 1024, dtype=jnp.float32)
    y1 = jnp.arange(0, 1024, dtype=jnp.float32)

    jit_fn = jax.jit(
        test_fn,
        in_shardings=(sharding, p_sharding),
        out_shardings=(sharding, p_sharding),
        donate_argnums=(0, 1),
    )
    x_out, y_out = jit_fn(x, y)
    self.assertArraysEqual(x_out, x1 * x1)
    self.assertArraysEqual(y_out, y1 + y1)

def test_compute_offload_with_linear_layout(self):
|
|
# TODO(apaszke): Remove after 12 weeks have passed.
|
|
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
|
|
self.skipTest("Requires libtpu built after 2024-12-19")
|
|
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
|
|
p_sharding = jax.sharding.SingleDeviceSharding(
|
|
jax.devices()[0], memory_kind="pinned_host"
|
|
)
|
|
|
|
@compute_on("device_host")
|
|
@jax.jit
|
|
def host_fn(x_in, y_in):
|
|
return x_in * x_in, y_in + y_in
|
|
|
|
def test_fn(x_in, y_in):
|
|
x_out, y_out = host_fn(x_in, y_in)
|
|
return x_out, y_out
|
|
|
|
x = jnp.arange(0, 1024, dtype=jnp.float32)
|
|
x = jnp.reshape(x, (16, 64))
|
|
y = jnp.arange(0, 1024, dtype=jnp.float32)
|
|
y = jnp.reshape(y, (16, 64))
|
|
custom_dll = DLL(major_to_minor=(0, 1), _tiling=((8, 128),))
|
|
custom_dll_linear = DLL(major_to_minor=(0, 1), _tiling=((1,),))
|
|
x = jax.device_put(x, Layout(custom_dll, sharding))
|
|
y = jax.device_put(y, Layout(custom_dll_linear, p_sharding))
|
|
|
|
x1 = jnp.arange(0, 1024, dtype=jnp.float32)
|
|
x1 = jnp.reshape(x1, (16, 64))
|
|
y1 = jnp.arange(0, 1024, dtype=jnp.float32)
|
|
y1 = jnp.reshape(y1, (16, 64))
|
|
|
|
jit_fn = jax.jit(
|
|
test_fn,
|
|
out_shardings=(
|
|
Layout(custom_dll, sharding),
|
|
Layout(custom_dll_linear, p_sharding),
|
|
),
|
|
)
|
|
x_out, y_out = jit_fn(x, y)
|
|
self.assertArraysEqual(x_out, x1 * x1)
|
|
self.assertArraysEqual(y_out, y1 + y1)
|
|
|
|
def test_compute_offload_mesh_with_linear_layout(self):
|
|
mesh = jtu.create_mesh((2, 2), ("x", "y"))
|
|
sharding = NamedSharding(mesh, P("x", "y"))
|
|
p_sharding = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host")
|
|
|
|
@compute_on("device_host")
|
|
@jax.jit
|
|
def host_fn(x_in, y_in):
|
|
return x_in * x_in, y_in + y_in
|
|
|
|
def test_fn(x_in, y_in):
|
|
x_out, y_out = host_fn(x_in, y_in)
|
|
return x_out, y_out
|
|
|
|
x = jnp.arange(0, 2048, dtype=jnp.float32)
|
|
x = jnp.reshape(x, (32, 64))
|
|
y = jnp.arange(0, 2048, dtype=jnp.float32)
|
|
y = jnp.reshape(y, (32, 64))
|
|
custom_dll = DLL(major_to_minor=(0, 1), _tiling=((8, 128),))
|
|
custom_dll_linear = DLL(major_to_minor=(0, 1), _tiling=((1,),))
|
|
x = jax.device_put(x, Layout(custom_dll, sharding))
|
|
y = jax.device_put(y, Layout(custom_dll_linear, p_sharding))
|
|
|
|
x1 = jnp.arange(0, 2048, dtype=jnp.float32)
|
|
x1 = jnp.reshape(x1, (32, 64))
|
|
y1 = jnp.arange(0, 2048, dtype=jnp.float32)
|
|
y1 = jnp.reshape(y1, (32, 64))
|
|
|
|
jit_fn = jax.jit(
|
|
test_fn,
|
|
out_shardings=(
|
|
Layout(custom_dll, sharding),
|
|
Layout(custom_dll_linear, p_sharding),
|
|
),
|
|
)
|
|
x_out, y_out = jit_fn(x, y)
|
|
self.assertArraysEqual(x_out, x1 * x1)
|
|
self.assertArraysEqual(y_out, y1 + y1)
|
|
|
|
def test_compute_on_cache_miss(self):
|
|
@jax.jit
|
|
def f(x):
|
|
return x * 2
|
|
|
|
inp = jnp.arange(10)
|
|
with jtu.count_jit_tracing_cache_miss() as count:
|
|
with compute_on('device_host'):
|
|
f(inp)
|
|
|
|
with compute_on('device'):
|
|
f(inp)
|
|
|
|
# 2 for `f` and `2` for `mul` (compute type changes for `mul`)
|
|
self.assertEqual(count(), 4)
|
|
|
|
  def test_offload_take_host(self):
    # TODO(apaszke): Remove after 12 weeks have passed.
    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
      self.skipTest("Requires libtpu built after 2024-12-19")
    @compute_on('device_host')
    @jax.jit
    def peer_forward(x, experts, indices, scores):
      w = jnp.take(experts, indices.astype(int), axis=0)
      w_gate, w_down, w_up = w[..., 0], w[..., 1], w[..., 2]
      g = jnp.einsum('btd, bthkd->bthk', x, w_gate)
      x = jnp.einsum('btd, bthkd->bthk', x, w_down)
      x = x * jax.nn.gelu(g) * scores
      return jnp.einsum('bthk, bthkd->btd', x, w_up)

    x = jnp.ones((16, 4, 32))
    experts = jnp.ones((128, 32, 3))
    indices = jnp.ones((16, 4, 4, 2), dtype=jnp.int32)
    scores = jnp.ones((16, 4, 4, 2))
    jax.jit(peer_forward)(x, experts, indices, scores)  # doesn't crash


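# The compute_on('gpu_stream:<n>') annotations below place each jitted
# sub-computation on its own GPU stream. With a 2x2 mesh, every shard of the
# (8, 8) all-ones inputs is (4, 4), so per shard x @ y == 4 and
# g(x, y) + h(3 * x, 2 * y) == 4 + 6 * 4 == 28, matching the assertion.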
class StreamAnnotationTest(jtu.JaxTestCase):

  def test_stream_annotation_inside_shmap(self):
    if not jtu.test_device_matches(["gpu"]):
      self.skipTest("Stream annotation is only supported on GPU.")
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    np_inp = np.ones((8, 8))
    arr1 = jax.device_put(np_inp, s)
    arr2 = jax.device_put(np_inp, s)

    @compute_on('gpu_stream:1')
    @jax.jit
    def g(x, y):
      return x @ y

    @compute_on('gpu_stream:2')
    @jax.jit
    def h(x, y):
      return x @ y

    def f(x, y):
      z = g(x, y)
      w = h(3 * x, 2 * y)
      return z + w

    out = jax.jit(shard_map(f, mesh=mesh,
                            in_specs=(P('x', 'y'), P('x', 'y')),
                            out_specs=P('x', 'y')))(arr1, arr2)
    self.assertArraysEqual(out, arr1 * 28)


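# The tests below offload remat residuals to pinned host memory and then look
# for buffers annotated with S(5) in the compiled HLO; memory space 5 is the
# color XLA uses for host memory on these backends.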
class ActivationOffloadingTest(jtu.JaxTestCase):

  def setUp(self):
    if not jtu.test_device_matches(["tpu", "gpu"]):
      self.skipTest("Memories do not work on CPU backend.")
    super().setUp()

  def test_remat_jaxpr_offloadable(self):
    mesh = jtu.create_mesh((2,), ("x",))
    inp = jax.device_put(np.arange(16.), NamedSharding(mesh, P("x")))

    def policy(prim, *avals, **params):
      return Offloadable(src="device", dst="pinned_host")

    @functools.partial(remat, policy=policy)
    def f(x):
      x = jnp.sin(x)
      x = jnp.sin(x)
      x = jnp.sin(x)
      return jnp.sum(x)

    fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, inp)

    self.assertLen(fwd_jaxpr.out_avals, 4)  # 1 output, 3 offloaded residuals
    fwd_mem_kind_count = str(fwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='pinned_host')")
    self.assertEqual(fwd_mem_kind_count, 3)

    self.assertLen(bwd_jaxpr.in_avals, 4)  # 3 offloaded residuals, 1 input
    bwd_mem_kind_count = str(bwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='device')")
    self.assertEqual(bwd_mem_kind_count, 3)

    # Execution test.
    f = jax.jit(jax.grad(f))
    f(inp)  # doesn't crash

    compiled_f = f.lower(inp).compile()

    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      self.assertIn('S(5)', compiled_text)
      self.assertRegex(compiled_text, r"copy-start.*S\(5\)")
      self.assertRegex(compiled_text, r"copy-done.*S\(5\)")

    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
      if jtu.pjrt_c_api_version_at_least(0, 43):
        self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

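  # In the scan tests below, save_and_offload_only_these_names keeps "y" on the
  # device and offloads "z" and "w" to pinned_host, which is why exactly two
  # TransferToMemoryKind markers are expected in each direction.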
  def test_remat_scan_jaxpr_offloadable(self):
    mesh = jtu.create_mesh((2,), ("x",))
    shape = (256, 128)
    np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    s = NamedSharding(mesh, P("x"))
    inp = jax.device_put(np_inp, s)

    with self.assertRaisesRegex(
        ValueError, "The names should be exclusive and should not intersect"):
      jax.checkpoint_policies.save_and_offload_only_these_names(
          names_which_can_be_saved=["y"], names_which_can_be_offloaded=["y", "w"],
          offload_src="device", offload_dst="pinned_host")

    policy = jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z", "w"],
        offload_src='device', offload_dst='pinned_host')

    @functools.partial(remat, policy=policy)
    def f(x):
      def g(ys, _):
        y, _ = ys
        y = checkpoint_name(jnp.sin(y), "y")
        z = checkpoint_name(jnp.sin(y), "z")
        z = jax.lax.with_sharding_constraint(z, s)
        w = checkpoint_name(jnp.sin(z), "w")
        return (w, jnp.sum(w)), None
      _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
      return scan_out

    fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, inp)

    # 2 outputs and 3 residuals; only "z" and "w" are offloaded, "y" is saved.
    self.assertLen(fwd_jaxpr.out_avals, 5)
    fwd_mem_kind_count = str(fwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='pinned_host')")
    self.assertEqual(fwd_mem_kind_count, 2)

    # 3 residuals (2 offloaded, 1 saved) and 2 inputs.
    self.assertLen(bwd_jaxpr.in_avals, 5)
    bwd_mem_kind_count = str(bwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='device')")
    self.assertEqual(bwd_mem_kind_count, 2)

    f = jax.jit(jax.grad(f))
    f(inp)  # doesn't crash

    compiled_f = f.lower(inp).compile()

    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      self.assertIn('S(5)', compiled_text)
      self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
      self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")

    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
      self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

  def test_remat_scan_layout_change_offloadable(self):
    mesh = jtu.create_mesh((2,), ("x",))
    shape = (256, 128)
    np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    s = NamedSharding(mesh, P("x"))
    inp = jax.device_put(np_inp, s)

    policy = jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z", "w"],
        offload_src='device', offload_dst='pinned_host')

    @functools.partial(remat, policy=policy)
    def f(x):
      def g(ys, _):
        y, _ = ys
        y = checkpoint_name(jnp.sin(y), "y")
        z = checkpoint_name(jnp.sin(y), "z")
        z = jax.lax.with_sharding_constraint(z, s)
        z = z.T
        w = checkpoint_name(jnp.sin(z), "w")
        return (w.T, jnp.sum(w)), None
      _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
      return scan_out

    f = jax.jit(jax.grad(f))
    f(inp)  # doesn't crash

    compiled_f = f.lower(inp).compile()

    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      self.assertIn('S(5)', compiled_text)
      self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
      self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
      self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)")
      self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)")
      self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)")
      self.assertIn("dynamic-slice-start", compiled_text)

    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
      self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

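  # offload_dot_with_no_batch_dims offloads the residuals produced by dot-like
  # ops without batch dimensions (the einsums below) to pinned_host and
  # recomputes everything else.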
  def test_remat_checkpoint_dots_with_no_batch_dims(self):
    policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
        "device", "pinned_host")

    @functools.partial(new_checkpoint, policy=policy)
    def f(x):
      x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
      x = jnp.sin(x)
      x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
      x = jnp.sin(x)
      x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
      x = jnp.sin(x)
      x = jnp.sum(x)
      return x

    inp = jnp.ones((2, 2))
    f = jax.jit(jax.grad(f))
    f(inp)  # doesn't crash

    compiled_f = f.lower(inp).compile()

    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      self.assertIn('S(5)', compiled_text)
      self.assertRegex(compiled_text, r"copy-start.*S\(5\)")
      self.assertRegex(compiled_text, r"copy-done.*S\(5\)")

    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
      self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

  def test_primitive_with_multiple_outputs(self):
    # Test for https://github.com/jax-ml/jax/issues/25841
    shape = (128,)
    inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

    def policy(prim, *args, **kwargs):
      del args, kwargs
      if prim.multiple_results:
        return Offloadable("device", "pinned_host")
      return Recompute

    @functools.partial(remat, policy=policy)
    def test_fn(x):
      # Need any primitive with multiple outputs and a non-trivial grad.
      x1, _ = jax.lax.approx_max_k(x, k=2)
      return jnp.sum(x1)

    fn = jax.grad(test_fn)
    jax.jit(fn)(inp)  # doesn't crash


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())