# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import functools
import math
import re
from absl.testing import absltest
from absl.testing import parameterized
from absl import flags
import unittest

import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.layout import DeviceLocalLayout as DLL, Layout
from jax._src import config
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.numpy as jnp
from jax.ad_checkpoint import Offloadable, remat, Recompute
from jax._src.sharding import common_devices_indices_map
from jax._src.sharding_impls import (NamedSharding, PositionalSharding,
                                     SingleDeviceSharding, GSPMDSharding,
                                     TransferToMemoryKind, PartitionSpec as P)
from jax.experimental.compute_on import compute_on
from jax.experimental.shard_map import shard_map
import numpy as np

config.parse_flags_with_absl()
FLAGS = flags.FLAGS


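# Module-level helpers shared by the test classes below: one reads the output
# memory kinds off a compiled executable, the other builds a (2, 2) mesh plus a
# NamedSharding (with an optional memory kind) and a matching device_put array.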
def get_memory_kinds_from_executable(f, args):
  compiled = f.lower(*args).compile()
  return compiled.runtime_executable().get_output_memory_kinds()[0]


def _create_inputs(shape, pspec, mem_kind=None):
  mesh = jtu.create_mesh((2, 2), ("x", "y"))
  np_inp = np.arange(math.prod(shape)).reshape(shape)
  s = NamedSharding(mesh, pspec, memory_kind=mem_kind)
  inp = jax.device_put(np_inp, s)
  return mesh, s, np_inp, inp


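# Checks how the sharding types (NamedSharding, PositionalSharding,
# SingleDeviceSharding, GSPMDSharding) canonicalize, validate, and compare
# memory kinds against the backend's default memory.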
class ShardingMemoriesTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if jtu.test_device_matches(["cpu"]):
      self._default_memory_kind = "unpinned_host"
    else:
      self._default_memory_kind = "device"

  @parameterized.named_parameters(
      ("named_sharding", "named_sharding"),
      ("positional_sharding", "positional_sharding"),
      ("single_device_sharding", "single_device_sharding"),
      ("gspmd_sharding", "gspmd_sharding"),
  )
  def test_canonicalize_memory_kind(self, name):
    if name == "named_sharding":
      mesh = jtu.create_mesh((1,), "x")
      ns = NamedSharding(mesh, P("x"))
      self.assertEqual(ns.memory_kind, self._default_memory_kind)
    elif name == "positional_sharding":
      ps = PositionalSharding(jax.devices())
      self.assertEqual(ps.memory_kind, self._default_memory_kind)
    elif name == "single_device_sharding":
      ss = SingleDeviceSharding(jax.devices()[0])
      self.assertEqual(ss.memory_kind, self._default_memory_kind)
    else:
      assert name == "gspmd_sharding"
      gs = GSPMDSharding.get_replicated(jax.devices())
      self.assertEqual(gs.memory_kind, self._default_memory_kind)

  @parameterized.named_parameters(
      ("named_sharding", "named_sharding"),
      ("positional_sharding", "positional_sharding"),
      ("single_device_sharding", "single_device_sharding"),
      ("gspmd_sharding", "gspmd_sharding"),
  )
  def test_wrong_memory_kind(self, name):
    if name == "named_sharding":
      with self.assertRaisesRegex(
          ValueError, "Could not find memory addressable by device.*"
      ):
        mesh = jtu.create_mesh((1,), ("x",))
        NamedSharding(mesh, P("x"), memory_kind="hbm")
    elif name == "positional_sharding":
      with self.assertRaisesRegex(
          ValueError, "Could not find memory addressable by device.*"
      ):
        PositionalSharding(jax.devices(), memory_kind="gpu_hbm")
    elif name == "single_device_sharding":
      with self.assertRaisesRegex(
          ValueError,
          "Could not find memory addressable by device.*Device.*"
          " can address the following memory kinds.*",
      ):
        SingleDeviceSharding(jax.devices()[0], memory_kind="host")
    else:
      assert name == "gspmd_sharding"
      with self.assertRaisesRegex(
          ValueError, "Could not find memory addressable by device.*"
      ):
        GSPMDSharding.get_replicated(jax.devices(), memory_kind="my_host")

  @parameterized.named_parameters(
      ("named_sharding", "named_sharding"),
      ("positional_sharding", "positional_sharding"),
      ("single_device_sharding", "single_device_sharding"),
      ("gspmd_sharding", "gspmd_sharding"),
  )
  def test_correct_tpu_memory_kind(self, name):
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("TPU memory kind test.")

    if name == "named_sharding":
      mesh = jtu.create_mesh((1,), ("x",))
      NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
    elif name == "positional_sharding":
      PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
    elif name == "single_device_sharding":
      SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host")
    else:
      assert name == "gspmd_sharding"
      GSPMDSharding.get_replicated(jax.devices(), memory_kind="unpinned_host")

  @parameterized.named_parameters(
      ("named_sharding", "named_sharding"),
      ("positional_sharding", "positional_sharding"),
      ("single_device_sharding", "single_device_sharding"),
      ("gspmd_sharding", "gspmd_sharding"),
  )
  def test_sharding_eq(self, name):
    if name == "named_sharding":
      mesh = jtu.create_mesh((1,), ("x",))
      s1 = NamedSharding(mesh, P("x"))
      s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
      self.assertEqual(s1, s2)
    elif name == "positional_sharding":
      s1 = PositionalSharding(jax.devices())
      s2 = PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
      self.assertEqual(s1, s2)
    elif name == "single_device_sharding":
      s1 = SingleDeviceSharding(jax.devices()[0])
      s2 = SingleDeviceSharding(jax.devices()[0], memory_kind=self._default_memory_kind)
      self.assertEqual(s1, s2)
    elif name == "gspmd_sharding":
      s1 = GSPMDSharding.get_replicated(jax.devices())
      s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind=self._default_memory_kind)
      self.assertEqual(s1, s2)

  def test_sharding_equivalent(self):
    mesh = jtu.create_mesh((1,), ("x",))
    ndim = 2
    ns1 = NamedSharding(mesh, P("x"))
    gs1 = GSPMDSharding(
        tuple(mesh.devices.flat),
        ns1._to_xla_hlo_sharding(ndim),
        memory_kind=self._default_memory_kind,
    )
    self.assertTrue(ns1.is_equivalent_to(gs1, ndim))

    ns2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
    gs2 = GSPMDSharding(
        tuple(mesh.devices.flat), ns2._to_xla_hlo_sharding(ndim)
    )
    self.assertTrue(ns2.is_equivalent_to(gs2, ndim))

  def test_default_memory_kind(self):
    dev = jax.devices()[0]
    self.assertEqual(dev.default_memory().kind, self._default_memory_kind)


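# Exercises jax.device_put transfers between device memory (HBM) and the
# pinned/unpinned host memory kinds, both outside and inside jit.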
class DevicePutTest(jtu.JaxTestCase):

  def setUp(self):
    if not jtu.test_device_matches(["tpu", "gpu"]):
      self.skipTest("Memories do not work on CPU backend yet.")
    super().setUp()

  def _check_device_put_addressable_shards(
      self, out, inp, expected_sharding, expected_mem_kind, index=True):
    self.assertArraysEqual(out, inp)
    self.assertEqual(out.sharding, expected_sharding)
    self.assertEqual(out.sharding.memory_kind, expected_mem_kind)
    for s in out.addressable_shards:
      if index:
        self.assertArraysEqual(s.data, inp[s.index])
      else:
        self.assertArraysEqual(s.data, inp)
      self.assertEqual(s.data.sharding.memory_kind, expected_mem_kind)

  def test_error_transfer_to_memory_kind_outside_jit(self):
    with self.assertRaisesRegex(
        ValueError,
        "TransferToMemoryKind argument to jax.device_put can only be used"
        " inside jax.jit"):
      jax.device_put(np.arange(16), TransferToMemoryKind("device"))

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_host_to_hbm(self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
    np_inp = np.arange(16).reshape(8, 2)

    out_on_host = jax.device_put(np_inp, s_host)
    self.assertEqual(out_on_host.sharding, s_host)

    s_hbm = s_host.with_memory_kind("device")
    out_on_hbm = jax.device_put(out_on_host, s_hbm)
    self._check_device_put_addressable_shards(
        out_on_hbm, np_inp, s_hbm, "device")

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_hbm_to_host(self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
    inp = jnp.arange(16).reshape(8, 2)

    out_on_host = jax.device_put(inp, s_host)
    self._check_device_put_addressable_shards(
        out_on_host, inp, s_host, host_memory_kind)

    sharded_inp = jax.device_put(inp, s_host.with_memory_kind("device"))
    sharded_out_on_host = jax.device_put(sharded_inp, s_host)
    self._check_device_put_addressable_shards(
        sharded_out_on_host, sharded_inp, s_host, host_memory_kind)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_different_device_and_memory_host_to_hbm(
      self, host_memory_kind: str
  ):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    if jax.device_count() < 3:
      raise unittest.SkipTest("Test requires >=3 devices")

    out_host0 = jax.device_put(
        jnp.arange(8),
        SingleDeviceSharding(jax.devices()[0], memory_kind=host_memory_kind))

    dev2 = jax.devices()[2]
    out_hbm1 = jax.device_put(
        out_host0, SingleDeviceSharding(dev2, memory_kind="device"))
    self.assertEqual(out_hbm1.sharding.memory_kind, "device")
    self.assertEqual(out_hbm1.sharding._device, dev2)
    self.assertEqual(out_hbm1.addressable_shards[0].data.sharding._device, dev2)
    self.assertEqual(
        out_hbm1.addressable_shards[0].data.sharding.memory_kind, "device")

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_different_device_and_memory_hbm_to_host(
      self, host_memory_kind: str
  ):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    if jax.device_count() < 3:
      raise unittest.SkipTest("Test requires >=3 devices")

    out_hbm0 = jnp.arange(8)

    dev2 = jax.devices()[2]
    out_host1 = jax.device_put(
        out_hbm0, SingleDeviceSharding(dev2, memory_kind=host_memory_kind))
    self.assertEqual(out_host1.sharding.memory_kind, host_memory_kind)
    self.assertEqual(out_host1.sharding._device, dev2)
    self.assertEqual(out_host1.addressable_shards[0].data.sharding._device,
                     dev2)
    self.assertEqual(
        out_host1.addressable_shards[0].data.sharding.memory_kind,
        host_memory_kind)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_on_different_device_with_the_same_memory_kind(
      self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    if len(jax.devices()) < 2:
      raise unittest.SkipTest("Test requires >=2 devices.")

    np_inp = np.arange(16).reshape(8, 2)

    s_hbm_dev_0 = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
    s_hbm_dev_1 = SingleDeviceSharding(jax.devices()[1], memory_kind="device")
    inp_hbm_dev0 = jax.device_put(np_inp, s_hbm_dev_0)
    out_hbm_dev_1 = jax.device_put(inp_hbm_dev0, s_hbm_dev_1)
    self._check_device_put_addressable_shards(
        out_hbm_dev_1, np_inp, s_hbm_dev_1, "device")

    inp_host_dev0 = jax.device_put(
        np_inp, s_hbm_dev_0.with_memory_kind(host_memory_kind))
    s_host_dev_1 = s_hbm_dev_1.with_memory_kind(host_memory_kind)
    out_host_dev_1 = jax.device_put(inp_host_dev0, s_host_dev_1)
    self._check_device_put_addressable_shards(
        out_host_dev_1, np_inp, s_host_dev_1, host_memory_kind)

  # TODO(yashkatariya): Enable this once we can compute on host.
  # def test_device_put_resharding(self):
  #   mesh = jtu.create_mesh((2, 2), ("x", "y"))
  #   s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")
  #   s_hbm = s_host.with_memory_kind("device")
  #   np_inp = np.arange(16).reshape(8, 2)

  #   # Reshard single device array on HBM to multi device on host
  #   sds_inp_hbm = jax.device_put(
  #       jnp.arange(16).reshape(8, 2),
  #       SingleDeviceSharding(jax.devices()[0], memory_kind="device"))
  #   # device_put on host
  #   out_sharded_host = jax.device_put(sds_inp_hbm, s_host)
  #   self._check_device_put_addressable_shards(
  #       out_sharded_host, np_inp, s_host, "unpinned_host")

  #   # Reshard single device array on host to multi device on hbm
  #   sds_inp_host = jax.device_put(
  #       jnp.arange(16).reshape(8, 2),
  #       SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host"))
  #   # device_put on hbm
  #   out_sharded_hbm = jax.device_put(sds_inp_host, s_hbm)
  #   self._check_device_put_addressable_shards(
  #       out_sharded_hbm, np_inp, s_hbm, "device")

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_numpy_array(self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    np_inp = np.arange(16).reshape(8, 2)
    s_hbm = NamedSharding(mesh, P(("x", "y")), memory_kind="device")
    s_host = s_hbm.with_memory_kind(host_memory_kind)

    out_hbm = jax.device_put(np_inp, s_hbm)
    self._check_device_put_addressable_shards(out_hbm, np_inp, s_hbm, "device")

    out_host = jax.device_put(np_inp, s_host)
    self._check_device_put_addressable_shards(
        out_host, np_inp, s_host, host_memory_kind)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_numpy_scalar(self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    np_inp = np.float32(8)
    s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
    s_host = s_hbm.with_memory_kind(host_memory_kind)

    out_hbm = jax.device_put(np_inp, s_hbm)
    self._check_device_put_addressable_shards(out_hbm, np_inp, s_hbm, "device")

    out_host = jax.device_put(np_inp, s_host)
    self._check_device_put_addressable_shards(
        out_host, np_inp, s_host, host_memory_kind)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_python_scalar(self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    py_scalar = float(8)
    s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
    s_host = s_hbm.with_memory_kind(host_memory_kind)

    out_hbm = jax.device_put(py_scalar, s_hbm)
    self._check_device_put_addressable_shards(
        out_hbm, py_scalar, s_hbm, "device", index=False)

    out_host = jax.device_put(py_scalar, s_host)
    self._check_device_put_addressable_shards(
        out_host, py_scalar, s_host, host_memory_kind, index=False)

  @parameterized.parameters("unpinned_host", "pinned_host")
  def test_device_put_python_int(self, host_memory_kind: str):
    if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host":
      self.skipTest("unpinned_host does not work on GPU backend.")
    py_inp = 8
    s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
    s_host = s_hbm.with_memory_kind(host_memory_kind)

    out_hbm = jax.device_put(py_inp, s_hbm)
    self._check_device_put_addressable_shards(
        out_hbm, py_inp, s_hbm, "device", index=False)

    out_host = jax.device_put(py_inp, s_host)
    self._check_device_put_addressable_shards(
        out_host, py_inp, s_host, host_memory_kind, index=False)

  def test_device_put_inside_jit(self):
    _, s_host, np_inp, inp_host = _create_inputs(
        (8, 2), P("x", "y"), mem_kind="pinned_host")
    s_dev = s_host.with_memory_kind("device")

    @jax.jit
    def f(a, b):
      x, y = jax.device_put((a, b), s_dev)
      return x * y

    out = f(inp_host, inp_host)
    self._check_device_put_addressable_shards(
        out, np_inp * np_inp, s_dev, "device")

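  # The streaming tests below assert on the compiled entry_computation_layout:
  # the S(5) annotation is how the compiled text marks buffers placed in host
  # memory space (assumption based on the regexes used in these tests).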
  def test_parameter_streaming(self):
    _, s_host, np_inp, inp_host = _create_inputs(
        (8, 2), P("x", "y"), mem_kind="pinned_host")
    s_dev = s_host.with_memory_kind('device')
    inp_dev = jax.device_put(np_inp, s_dev)

    @functools.partial(jax.jit, out_shardings=s_host)
    def f(a, b):
      x = b * 2
      y = jax.device_put(a, s_dev)
      z = x * y
      return z * 4, z

    compiled = f.lower(inp_host, inp_dev).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out1, out2 = f(inp_host, inp_dev)
    self._check_device_put_addressable_shards(
        out1, np_inp * np_inp * 8, s_host, 'pinned_host')
    self._check_device_put_addressable_shards(
        out2, np_inp * np_inp * 2, s_host, 'pinned_host')

  def test_zero_size_parameter(self):
    if jtu.test_device_matches(["gpu"]):
      self.skipTest("This test does not work on GPU backend.")
    _, s_host, np_inp, inp_host = _create_inputs(
        (0,), P(), mem_kind="pinned_host")
    s_dev = s_host.with_memory_kind('device')

    @functools.partial(jax.jit, out_shardings=s_host)
    def f(a):
      b = jax.device_put(a, s_dev)
      return b

    compiled = f.lower(inp_host).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out = f(inp_host)
    self._check_device_put_addressable_shards(
        out, np_inp, s_host, 'pinned_host')

  def test_parameter_streaming_with_scalar_and_constant(self):
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    scalar_inp = 1
    s_host = NamedSharding(mesh, P(), memory_kind="pinned_host")

    @functools.partial(jax.jit, out_shardings=s_host)
    def f(scalar_input):
      y = jax.device_put(scalar_input, s_host)
      z = 2
      w = jax.device_put(z, s_host)
      return y, w

    compiled = f.lower(scalar_inp).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out1, out2 = f(scalar_inp)
    self._check_device_put_addressable_shards(
        out1, scalar_inp, s_host, "pinned_host", index=False
    )
    self._check_device_put_addressable_shards(
        out2, 2, s_host, "pinned_host", index=False
    )

  def test_parameter_and_output_streaming_with_array(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    np_inp = np.arange(16).reshape(8, 2)
    s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host")
    inp_host = jax.device_put(np_inp, s_host)

    @functools.partial(jax.jit, out_shardings=(s_host, s_host))
    def f(x):
      return (x, x)

    compiled = f.lower(inp_host).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    if compiled_text is not None:
      self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out1, out2 = f(inp_host)
    self._check_device_put_addressable_shards(
        out1, np_inp, s_host, "pinned_host"
    )
    self._check_device_put_addressable_shards(
        out2, np_inp, s_host, "pinned_host"
    )

  def test_parameter_and_output_streaming_with_scalar(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")

    mesh = jax.sharding.Mesh(jax.devices(), "axis")
    s_host = jax.sharding.NamedSharding(
        mesh, jax.sharding.PartitionSpec(), memory_kind="pinned_host"
    )
    scalar_inp = 1

    @functools.partial(jax.jit, out_shardings=(s_host, s_host))
    def f(x):
      return (x, x)

    compiled = f.lower(scalar_inp).compile()  # doesn't crash
    compiled_text = compiled.as_text()
    if compiled_text is not None:
      self.assertRegex(compiled_text, r"entry_computation_layout=.*S\(5\)}")

    out1, out2 = f(scalar_inp)
    self._check_device_put_addressable_shards(
        out1, scalar_inp, s_host, "pinned_host", index=False
    )
    self._check_device_put_addressable_shards(
        out2, scalar_inp, s_host, "pinned_host", index=False
    )

  def test_identity_jit_host_to_device_and_vice_versa(self):
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    np_inp = np.arange(16).reshape(8, 2)
    s_host = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
    s_dev = s_host.with_memory_kind('device')
    arr_host = jax.device_put(np_inp, s_host)
    arr_dev = jax.device_put(np_inp, s_dev)

    # pinned_host -> device
    f = jax.jit(lambda x: x, out_shardings=s_dev)
    out_dev = f(arr_host)
    self.assertArraysEqual(out_dev, np_inp)
    self.assertEqual(out_dev.sharding, s_dev)

    # device -> pinned_host
    g = jax.jit(lambda x: x, out_shardings=s_host)
    out_host = g(arr_dev)
    self.assertArraysEqual(out_host, np_inp)
    self.assertEqual(out_host.sharding, s_host)

  def test_parameter_streaming_inside_scan(self):
    mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z"))
    np_inp = np.arange(4096.0).reshape(16, 16, 16)
    s_host = NamedSharding(mesh, P("x", "y", "z"), memory_kind="pinned_host")
    arr_host = jax.device_put(np_inp, s_host)

    @jax.jit
    def f(xs):
      def body(carry, x):
        x_tpu = jax.device_put(x, TransferToMemoryKind("device"))
        return carry, x_tpu + carry

      return jax.lax.scan(body, 1.0, xs)

    _, out_hbm = f(arr_host)
    self.assertArraysEqual(out_hbm, np_inp + 1.0)
    # Only expect the last dimension to have a named sharding.
    out_s = NamedSharding(mesh, P(None, None, "z"), memory_kind="device")
    self.assertEqual(out_hbm.sharding, out_s)

  def test_output_streaming(self):
    mesh = jtu.create_mesh((1, 1), ("x", "y"))
    np_inp = np.arange(16.0).reshape(8, 2)
    s_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="device")
    s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host")
    arr_hbm = jax.device_put(np_inp, s_hbm)

    @functools.partial(jax.jit, out_shardings=s_host)
    def f(xs):
      out_tpu = xs + 1.0
      return out_tpu

    out_host = f(arr_hbm)
    self.assertArraysEqual(out_host, np_inp + 1.0)
    self.assertEqual(out_host.sharding, s_host)

  def test_weight_offload_with_dp_on_output(self):
    _, s_dev, np_inp, inp_dev = _create_inputs(
        (8, 2), P("x", "y"), mem_kind="device")
    s_host = s_dev.with_memory_kind('pinned_host')

    @jax.jit
    def f(x):
      x = x * 2
      y = jax.device_put(x, s_host)
      return y

    out_host = f(inp_dev)
    self._check_device_put_addressable_shards(
        out_host, np_inp * 2, s_host, 'pinned_host')

  def test_output_streaming_inside_scan(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z"))
    np_inp = np.arange(4096).reshape(16, 16, 16)
    s_hbm = NamedSharding(mesh, P(None, "y", "z"), memory_kind="device")
    arr_hbm = jax.device_put(np_inp, s_hbm)

    @jax.jit
    def f(xs):
      def body(carry, x):
        out_tpu = x + carry
        return carry, jax.device_put(
            out_tpu, NamedSharding(mesh, P("y", "z"), memory_kind="pinned_host"))
      _, res = jax.lax.scan(body, 1, xs)
      return res

    out = f(arr_hbm)
    self.assertArraysEqual(out, np_inp + 1)
    self.assertEqual(out.sharding.memory_kind, 'pinned_host')

  def test_deepcopy(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    mesh = jax.sharding.Mesh(jax.devices(), "x")
    s_host = NamedSharding(mesh, P(), memory_kind="pinned_host")

    t = jax.device_put(jnp.zeros((8, 2)), s_host)
    t_copy = copy.deepcopy(t)
    self.assertArraysEqual(t, t_copy)
    self.assertEqual(t.shape, t_copy.shape)

  def test_close_over_host_constant_and_stream(self):

    _, s_host, np_inp, inp_host = _create_inputs(
        (8, 2), P("x", "y"), mem_kind="pinned_host")
    s_dev = s_host.with_memory_kind('device')

    @functools.partial(jax.jit, out_shardings=s_dev)
    def f():
      y = jax.device_put(inp_host, s_dev)
      z = y * 2
      return z

    out = f()
    self._check_device_put_addressable_shards(out, np_inp * 2, s_dev, 'device')

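  # Streams a ragged prefix of a large sharded array to pinned host memory in
  # fixed-size chunks, stopping at the first all-zero chunk, and checks that
  # the compiled module allocates the host output buffer via AllocateBuffer.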
  @jtu.run_on_devices('tpu')
  def test_ragged_copy_on_host(self):
    mesh = jtu.create_mesh((2,), ('x'))
    sharding = jax.sharding.NamedSharding(mesh, P(('x')))
    cpu_sharding = sharding.with_memory_kind('pinned_host')

    num_pages = 512 * 1024
    page_size = 1024

    x = jnp.full((num_pages, page_size), 1, dtype=jnp.bfloat16, device=sharding)

    def write(x):
      return x.at[16 * 1024:].set(0)
    x = shard_map(write, mesh, P(('x'),), P(('x')))(x)

    chunk_size = 8
    def inner(state):
      idx, x, output = state
      chunk = jax.lax.dynamic_slice_in_dim(x, idx * chunk_size, chunk_size)
      chunk_host = jax.device_put(chunk, TransferToMemoryKind('pinned_host'))
      output = jax.lax.dynamic_update_slice_in_dim(
          output, chunk_host, idx * chunk_size, axis=0)
      return (idx + 1, x, output)

    def cond(state):
      idx, x, _ = state
      chunk = jax.lax.dynamic_slice_in_dim(x, idx * chunk_size, chunk_size)
      return (idx * chunk_size < x.shape[0]) & jnp.any(chunk > 0)

    def foo(x):
      output = jnp.zeros_like(x, device=cpu_sharding)
      _, _, cpu_x = jax.lax.while_loop(cond, inner, (0, x, output))
      return cpu_x

    fn = jax.jit(shard_map(foo, mesh, P(('x'),), P(('x')),
                           check_rep=False),
                 out_shardings=cpu_sharding)
    y = fn(x)
    jax.block_until_ready(y)
    compiled_text = fn.lower(x).compile().as_text()
    if compiled_text is not None:
      self.assertIn('custom_call_target="AllocateBuffer"', compiled_text)

  def test_disallow_alias_copies_arrays(self):
    mesh = jtu.create_mesh((2,), ("x",))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P("x"), memory_kind="pinned_host")
    inp_host = jax.device_put(np_inp, s)

    inp_host_copy = jax.device_put(inp_host, may_alias=False)

    for a in jax.tree.leaves(inp_host):
      a.delete()

    jax.block_until_ready(inp_host_copy)

  def test_disallow_alias_copies_arrays_with_donated_input(self):
    mesh = jtu.create_mesh((2,), ("x",))
    np_inp = np.arange(16).reshape(8, 2)
    s = NamedSharding(mesh, P("x"), memory_kind="pinned_host")
    inp_host = jax.device_put(np_inp, s)

    inp_host_donate = jax.jit(lambda x: x, donate_argnums=0)(inp_host)

    inp_host_donate_copy = jax.device_put(inp_host_donate, may_alias=False)

    for a in jax.tree.leaves(inp_host_donate):
      a.delete()

    jax.block_until_ready(inp_host_donate_copy)


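# Covers compute_on('device_host'): running jitted sub-computations on the
# host, mixing host and device memory kinds, donation, layouts, and caching.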
class ComputeOffload(jtu.BufferDonationTestCase):

  def setUp(self):
    if not jtu.test_device_matches(["tpu"]):
      self.skipTest("Memories do not work on CPU and GPU backends yet.")
    super().setUp()

  def _check_mem_kind(self, executable_kind, out_sharding, expected_kind):
    out_kind = out_sharding.memory_kind
    self.assertEqual(executable_kind, out_kind)
    self.assertEqual(out_kind, expected_kind)
    self.assertEqual(executable_kind, expected_kind)

  def test_compute_no_inputs(self):
    mesh = jtu.create_mesh((4,), ('data'))

    tpu_sharding = NamedSharding(mesh, P('data'))
    cpu_sharding = NamedSharding(mesh, P('data'), memory_kind='pinned_host')

    @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding))
    def init():
      tpu_array = jax.random.normal(jax.random.key(42), (16, 16))
      cpu_array = jax.random.normal(jax.random.key(42), (16, 16))
      return tpu_array, cpu_array

    tpu_array, cpu_array = init()
    self.assertEqual(tpu_array.sharding, tpu_sharding)
    self.assertEqual(cpu_array.sharding, cpu_sharding)

  def test_compute_no_inputs_host_replicated(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3:
      self.skipTest("This test requires an xla_version >= 3.")
    if config.use_shardy_partitioner.value:
      self.skipTest("XLA failure due to b/370786664 and b/366411266. "
                    "Enable when fixed.")
    mesh = jtu.create_mesh((4,), ('data'))

    tpu_sharding = NamedSharding(mesh, P('data'))
    cpu_sharding = NamedSharding(mesh, P(), memory_kind='pinned_host')

    @functools.partial(jax.jit, out_shardings=(tpu_sharding, cpu_sharding))
    def init():
      tpu_array = jax.random.normal(jax.random.key(42), (16, 16))
      cpu_array = jax.random.normal(jax.random.key(42), (16, 16))
      return tpu_array, cpu_array

    tpu_array, cpu_array = init()
    self.assertEqual(tpu_array.sharding, tpu_sharding)
    self.assertEqual(cpu_array.sharding, cpu_sharding)

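  # compute_on('device_host') wraps the decorated function so that it lowers as
  # a host computation; the tests below look for the `_xla_compute_type`
  # frontend attribute in the lowered text to confirm that.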
  def test_compute_on_basic(self):
    out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host')

    @compute_on('device_host')
    @jax.jit
    def g(x):
      return x * 2

    @jax.jit
    def f(x):
      y = g(x)
      return y * 3

    inp = jnp.arange(8)
    out = f(inp)
    self.assertArraysEqual(out, inp * 6)

    lowered_text = f.lower(jnp.arange(8)).as_text()
    self.assertIn('_xla_compute_type', lowered_text)

    @functools.partial(jax.jit, out_shardings=out_s)
    def h(x):
      y = g(x)
      return y * 3

    out2 = h(inp)
    self.assertArraysEqual(out2, inp * 6)
    self.assertEqual(out2.sharding.memory_kind, 'pinned_host')

  def test_compute_on_host_shared_sharding(self):
    mesh = jtu.create_mesh((2,), ("x"))
    device_sharding = NamedSharding(mesh, P("x"))
    host_sharding = device_sharding.with_memory_kind("pinned_host")

    @compute_on("device_host")
    @functools.partial(
        jax.jit,
        in_shardings=(host_sharding, device_sharding),
        out_shardings=(host_sharding, device_sharding),
        donate_argnums=(0, 1),
    )
    def host_func(x, y):
      return (x * y), ((x**2) * (y**2))

    @functools.partial(
        jax.jit,
        in_shardings=(host_sharding, device_sharding),
        out_shardings=(host_sharding, device_sharding),
        donate_argnums=(0),
    )
    def device_func(host_data, device_data):
      host_data, device_data = host_func(host_data, device_data)
      device_data = device_data * 2
      host_data, device_data = host_func(host_data, device_data)
      return (host_data, device_data)

    input_x = jnp.ones(8)
    input_host = jax.device_put(input_x, host_sharding)

    input_device = jnp.arange(8)
    input_device = jnp.where(input_device < 4, 0, 1)
    input_device = jax.device_put(input_device, device_sharding)

    output_host, output_device = device_func(input_host, input_device)
    self.assertEqual(output_host.sharding.memory_kind, 'pinned_host')
    self.assertEqual(output_device.sharding.memory_kind, 'device')
    self.assertArraysEqual(output_host, [0., 0., 0., 0., 2., 2., 2., 2.])
    self.assertArraysEqual(output_device, [0., 0., 0., 0., 4., 4., 4., 4.])

  def test_compute_on_basic_inline(self):
    @compute_on('device_host')
    @jax.jit
    def g(x):
      return x * 2

    @functools.partial(jax.jit, inline=True)
    def h(x):
      y = g(x)
      return y * 3

    @jax.jit
    def f(x):
      return h(x)

    inp = jnp.arange(8)
    out = f(inp)
    self.assertArraysEqual(out, inp * 6)

    lowered_text = f.lower(jnp.arange(8)).as_text('hlo')
    self.assertRegex(lowered_text,
                     'to_apply=g.*frontend_attributes={_xla_compute_type="host"}')

  def test_compute_on_reduction(self):
    out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host')

    @compute_on('device_host')
    @jax.jit
    def g(x):
      # Reduction generates multiple host computations (inside a single host
      # computation module): the main one and a reduction body.
      return jnp.sum(x)

    @jax.jit
    def f(x):
      y = g(x)
      z = jnp.sum(x)
      return y * z

    inp = jnp.arange(8)
    out = f(inp)
    self.assertArraysEqual(out, np.sum(inp) * np.sum(inp))

    lowered_text = f.lower(jnp.arange(8)).as_text()
    self.assertIn('_xla_compute_type', lowered_text)

    @functools.partial(jax.jit, out_shardings=out_s)
    def h(x):
      y = g(x)
      z = jnp.sum(x)
      return y * z

    out2 = h(inp)
    self.assertArraysEqual(out2, np.sum(inp) * np.sum(inp))
    self.assertEqual(out2.sharding.memory_kind, 'pinned_host')

  def test_compute_host_loop(self):
    # TODO(apaszke): Remove after 12 weeks have passed.
    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
      self.skipTest("Requires libtpu built after 2024-12-19")
    @compute_on('device_host')
    @jax.jit
    def fn():
      k = jax.random.key(0)
      return jax.nn.initializers.lecun_normal()(k, (2, 2), jnp.float32)
    fn()  # doesn't crash

    @compute_on('device_host')
    def fn():
      k = jax.random.key(0)
      return jax.nn.initializers.lecun_normal()(k, (2, 2), jnp.float32)
    fn()  # doesn't crash

  def test_nested_compute_error(self):
    @compute_on('device')
    @jax.jit
    def f0(x):
      return x * 2

    @compute_on('device_host')
    @jax.jit
    def f1(x):
      return f0(x)

    @jax.jit
    def f2(x):
      return f1(x)

    with self.assertRaisesRegex(
        NotImplementedError,
        "Nesting `compute_on` with different compute types is not supported"
        " yet."):
      f2(jnp.arange(8))

  def test_compute_on_grad(self):
    @compute_on('device_host')
    @jax.jit
    def g(x):
      return jnp.sin(x)

    def f(x):
      y = g(x)
      return jnp.sum(y)

    inp = jnp.arange(8.)
    jf = jax.jit(jax.grad(f))

    jtu.check_grads(jf, (inp,), order=2)

    lowered_text = jf.lower(inp).as_text('hlo')
    out = re.findall(r"call.*to_apply.*_xla_compute_type", lowered_text)
    self.assertLen(out, 2)

  def test_compute_on_remat(self):
    inp = jnp.arange(16.)

    def policy(prim, *avals, **params):
      return Recompute

    @compute_on('device_host')
    @jax.jit
    def g(x):
      x = jnp.sin(x)
      x = jnp.sin(x)
      x = jnp.sin(x)
      return x

    @functools.partial(remat, policy=policy)
    def f(x):
      x = g(x)
      return jnp.sum(x)

    # Execution test.
    jf = jax.jit(jax.grad(f))
    jf(inp)  # doesn't crash

    lowered_text = jf.lower(inp).as_text('hlo')
    out = re.findall(r"call.*to_apply.*_xla_compute_type", lowered_text)
    self.assertLen(out, 2)

  def test_nested_no_op_compute(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    arr = jax.device_put(np_inp, s)

    @compute_on('device_host')
    @jax.jit
    def f0(x):
      return x * 2

    @compute_on('device_host')
    @jax.jit
    def f1(x):
      x = x * 3
      return f0(x)

    @jax.jit
    def f2(x):
      return f1(x)

    out = f2(arr)
    self.assertArraysEqual(out, arr * 6)
    self.assertEqual(out.sharding, s)

  def test_sharded_compute_on_host(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    arr = jax.device_put(np_inp, s)

    @compute_on('device_host')
    @jax.jit
    def g(x, y):
      return x * y

    @jax.jit
    def f(x):
      x = x * 3
      return g(x, x)

    out = f(arr)
    expected_out = (np_inp * 3) * (np_inp * 3)
    self.assertEqual(out.sharding, s)
    self.assertArraysEqual(out, expected_out)

  def test_host_offload_in_custom_vjp(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    @jax.custom_vjp
    def f(x):
      return jnp.sin(x)

    @compute_on('device_host')
    @jax.jit
    def eq(x, y):
      return (x == y).astype(jnp.float32)

    def f_fwd(x):
      y = x * 2
      z = jax.device_put(y, TransferToMemoryKind('pinned_host'))
      return y, (x, z)

    def f_bwd(res, tx):
      x, z = res
      y = x * 2
      z2 = jax.device_put(y, TransferToMemoryKind('pinned_host'))
      return (eq(z, z2),)

    f.defvjp(f_fwd, f_bwd)
    g = jax.jit(jax.grad(lambda x: f(x).sum()))

    x = jnp.ones(3) * 4
    all_true = jnp.ones(3)
    self.assertArraysEqual(g(x), all_true)

  def test_host_offload_in_custom_vjp_sharded(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    s = NamedSharding(mesh, P('x'))

    @jax.custom_vjp
    def f(x):
      return jnp.sin(x)

    @compute_on('device_host')
    @jax.jit
    def eq(x, y):
      return (x == y).astype(jnp.float32)

    def f_fwd(x):
      y = x * 2
      z = jax.device_put(y, s.with_memory_kind('pinned_host'))
      return y, (x, z)

    def f_bwd(res, tx):
      x, z = res
      y = x * 2
      z2 = jax.device_put(y, s.with_memory_kind('pinned_host'))
      return (eq(z, z2),)

    f.defvjp(f_fwd, f_bwd)
    g = jax.jit(jax.grad(lambda x: f(x).sum()))

    arr = jax.device_put(jnp.ones(4) * 4, s)
    all_true = jnp.ones(4)
    self.assertArraysEqual(g(arr), all_true)

  def test_scan_offload(self):
    np_inp = jnp.arange(4096).reshape(16, 16, 16)

    @jax.jit
    def f(xs):
      def body(carry, x):
        with compute_on('device_host'):
          out_tpu = x + carry
        return carry, out_tpu
      _, res = jax.lax.scan(body, 1, xs)
      return res

    out = f(np_inp)
    self.assertArraysEqual(out, np_inp + 1)

    @compute_on('device_host')
    @jax.jit
    def body2(carry, x):
      out_tpu = x + carry
      return carry, out_tpu

    @jax.jit
    def f2(xs):
      _, res = jax.lax.scan(body2, 1, xs)
      return res

    out2 = f2(np_inp)
    self.assertArraysEqual(out2, np_inp + 1)

  @parameterized.parameters(True, False)
  def test_copy_offload(self, jit_compute_fn: bool):
    # test an explicit copy within the host computation.

    def g(x):
      return jnp.copy(x) * 2

    @jax.jit
    def f(x):
      if jit_compute_fn:
        y = compute_on("device_host")(jax.jit(g))(x)
      else:
        y = compute_on("device_host")(g)(x)
      return y * 3

    inp = jnp.arange(8)
    out = f(inp)
    self.assertArraysEqual(out, inp * 6)

    lowered_text = f.lower(jnp.arange(8)).as_text()
    self.assertIn('_xla_compute_type', lowered_text)

  def test_pure_host_data_and_compute(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
    np_inp = np.arange(16).reshape(8, 2)
    arr_host = jax.device_put(np_inp, s)

    @compute_on('device_host')
    @jax.jit
    def g(x):
      return x * x

    @functools.partial(jax.jit, out_shardings=s)
    def f(x):
      return g(x)

    out = f(arr_host)
    self.assertEqual(out.sharding, s)
    self.assertEqual(out.sharding.memory_kind, 'pinned_host')
    self.assertArraysEqual(out, np_inp * np_inp)

  def test_eager_compute(self):
    inp = jnp.arange(8.)
    with compute_on('device_host'):
      out = inp * 2
      out = jnp.sin(out)
    self.assertArraysAllClose(out, jnp.sin(inp * 2))

  def test_compute_per_annotation(self):
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    s = NamedSharding(mesh, P("x", "y"))
    np_inp = np.arange(16.).reshape(8, 2)
    arr = jax.device_put(np_inp, s)

    @jax.jit
    @compute_on('device_host')
    def f(x):
      return jnp.sin(x * 2)

    # sharded input
    out = f(arr)
    self.assertArraysAllClose(out, np.sin(np_inp * 2))

    out2 = f(np_inp)
    self.assertArraysAllClose(out2, np.sin(np_inp * 2))

  def test_jit_host_multi_outputs(self):
    if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
      self.skipTest("This test requires an xla_version >= 2.")
    _, s, np_inp, inp = _create_inputs((8, 2), P("x"))

    @jax.jit
    def f(x, y):
      x, y = jnp.sin(x), jnp.cos(y)
      x = jax.device_put(x, s.with_memory_kind("pinned_host"))
      y = jax.device_put(y, s.with_memory_kind("device"))
      return x, y

    out1, out2 = f(inp, inp)

    self.assertArraysAllClose(out1, np.sin(np_inp))
    self.assertArraysAllClose(out2, np.cos(np_inp))
    self.assertEqual(out1.sharding, s.with_memory_kind("pinned_host"))
    self.assertEqual(out2.sharding, s.with_memory_kind("device"))

  def test_jit_out_shardings_single_output(self):
    mesh, _, _, inp = _create_inputs((8, 2), P("x", "y"))
    out_s = NamedSharding(mesh, P(), memory_kind="pinned_host")

    @functools.partial(jax.jit, out_shardings=out_s)
    def g(x):
      return jnp.sum(x * 2)

    out = g(inp)
    self.assertEqual(out.sharding, out_s)
    executable_mk = get_memory_kinds_from_executable(g, [inp])
    self._check_mem_kind(executable_mk[0], out.sharding, "pinned_host")

    @jax.jit
    def h(x):
      x = jnp.sum(x * 2)
      out = jax.device_put(x, out_s)
      return out

    out = h(inp)
    self.assertEqual(out.sharding, out_s)
    executable_mk = get_memory_kinds_from_executable(h, [inp])
    self._check_mem_kind(executable_mk[0], out.sharding, "pinned_host")

  def test_jit_in_shardings(self):
    _, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"))

    @functools.partial(jax.jit, in_shardings=s.with_memory_kind("pinned_host"))
    def f(x):
      return x * 2

    with self.assertRaisesRegex(
        ValueError,
        "Memory kinds passed to jax.jit does not match memory kind on the"
        " respective arg. Got pjit memory kind: pinned_host, arg memory kind:"
        " device for arg shape.*"):
      f(jnp.arange(16).reshape(8, 2))  # uncommitted inp also raises error

    with self.assertRaisesRegex(
        ValueError,
        "Memory kinds passed to jax.jit does not match memory kind on the"
        " respective arg. Got pjit memory kind: pinned_host, arg memory kind:"
        " device for arg shape.*"):
      f(inp)  # committed inp raises error.

    @functools.partial(jax.jit, in_shardings=s.with_memory_kind("device"))
    def g(x):
      return x * 2

    out = g(inp)
    executable_kind = get_memory_kinds_from_executable(g, [inp])
    self.assertArraysEqual(out, np_inp * 2)
    self._check_mem_kind(executable_kind[0], out.sharding, "device")

  def test_jit_in_out_shardings(self):
    mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"), mem_kind="device")
    out_s = NamedSharding(mesh, P(), memory_kind="device")

    @functools.partial(jax.jit, in_shardings=s, out_shardings=out_s)
    def f(x):
      return jnp.sum(x)

    out = f(inp)
    executable_kind = get_memory_kinds_from_executable(f, [inp])
    self.assertArraysEqual(out, np.sum(np_inp))
    self._check_mem_kind(executable_kind[0], out.sharding, "device")

    @functools.partial(
        jax.jit,
        in_shardings=s,
        out_shardings=out_s.with_memory_kind("pinned_host"),
    )
    def g(x):
      return jnp.sum(x)

    out = g(inp)
    executable_kind = get_memory_kinds_from_executable(g, [inp])
    self.assertArraysEqual(out, np.sum(np_inp))
    self._check_mem_kind(executable_kind[0], out.sharding, "pinned_host")

  def test_device_put_different_devices(self):
    _, _, _, inp = _create_inputs((8, 2), P("x", "y"))

    @jax.jit
    def f(x):
      return jax.device_put(
          x, SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host"))

    with self.assertRaisesRegex(
        ValueError, "Received incompatible devices for jitted computation"):
      f(inp)

  def test_jit_cpp_cache_hit(self):
    mesh, _, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
    inp2 = jax.device_put(
        np_inp, NamedSharding(mesh, P("x", "y"), memory_kind="device"))

    f = jax.jit(lambda x: x @ x.T)

    with jtu.count_pjit_cpp_cache_miss() as count:
      out = f(inp)
      out2 = f(inp2)
    self.assertEqual(count(), 1)

    self.assertArraysEqual(out, np_inp @ np_inp.T)
    self.assertArraysEqual(out2, np_inp @ np_inp.T)

  def test_jit_compilation_cache_hit(self):
    if config.use_shardy_partitioner.value:
      self.skipTest("Shardy doesn't support GSPMDSharding")
    mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
    inp2 = jax.device_put(
        np_inp, GSPMDSharding(tuple(mesh.devices.flat),
                              s._to_xla_hlo_sharding(inp.ndim),
                              memory_kind="device")
    )

    f = jax.jit(lambda x: x @ x.T)

    with (jtu.count_pjit_cpp_cache_miss() as cpp_count,
          jtu.count_jit_and_pmap_lowerings() as lowering_count):
      f(inp)
      f(inp2)
    self.assertEqual(cpp_count(), 2)
    self.assertEqual(lowering_count(), 2)

  def test_jit_cpp_cache_output_hit(self):
    _, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="device")

    @jax.jit
    def mul_two(x):
      return x * 2

    with jtu.count_pjit_cpp_cache_miss() as count:
      out = mul_two(inp)
      mul_two(out)
    self.assertEqual(count(), 1)

  def test_jit_cache_hit_with_default_and_specified_mem_kind(self):
    _, s, np_inp, _ = _create_inputs((8, 2), P("x", "y"))
    _, s2, np_inp2, _ = _create_inputs((8, 2), P("x", "y"), mem_kind="device")

    def mul(x):
      return x @ x.T

    f = jax.jit(mul, in_shardings=s)
    g = jax.jit(mul, in_shardings=s2)

    with jtu.count_jit_and_pmap_lowerings() as count:
      out = f(np_inp)
      out2 = g(np_inp2)
    self.assertEqual(count(), 1)

    self.assertArraysEqual(out, np_inp @ np_inp.T)
    self.assertArraysEqual(out2, np_inp2 @ np_inp2.T)

  def test_sharding_devices_indices_map_cache_hit(self):
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    shape = (8, 2)
    s1 = NamedSharding(mesh, P("x", "y"))
    s2 = NamedSharding(mesh, P("x", "y"), memory_kind="device")

    s1.devices_indices_map(shape)
    cache_info1 = common_devices_indices_map.cache_info()
    s2.devices_indices_map(shape)
    cache_info2 = common_devices_indices_map.cache_info()
    self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
    self.assertEqual(cache_info2.misses, cache_info1.misses)

  def test_no_donation_across_memory_kinds(self):
    if xb.using_pjrt_c_api():
      raise unittest.SkipTest("GetOutputShardings not supported in PJRT C API")
    mesh = jtu.create_mesh((2, 1), ("x", "y"))
    np_inp = np.arange(16).reshape(8, 2)
    s_hbm = NamedSharding(mesh, P("x"))
    s_host = s_hbm.with_memory_kind("pinned_host")
    inp = jax.device_put(np_inp, s_hbm)

    @functools.partial(jax.jit, out_shardings=s_host, donate_argnums=0)
    def f(x):
      return x * 2

    with self.assertWarnsRegex(
        UserWarning, "Some donated buffers were not usable"):
      f(inp)

    lowered_text = f.lower(inp).as_text("hlo")
    self.assertNotIn("input_output_alias", lowered_text)
    self.assertNotDeleted(inp)

  def test_single_mem_kind_donation_default_mem_kind(self):
    mesh = jtu.create_mesh((2,), "x")
    s = NamedSharding(mesh, P())

    @functools.partial(jax.jit, out_shardings=s, donate_argnums=0)
    def f(inp1):
      return inp1 * 2

    x = jax.device_put(np.arange(16).reshape(8, 2), s)

    f(x)

    lowered_text = f.lower(x).as_text("hlo")
    self.assertIn("input_output_alias", lowered_text)
    self.assertDeleted(x)

  def test_compute_offload_inside_shmap(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    arr = jax.device_put(np_inp, s)

    @compute_on('device_host')
    @jax.jit
    def g(x):
      return x * 2

    def f(x):
      x = x * 3
      y = g(x)
      return y * 4

    out = jax.jit(shard_map(f, mesh=mesh, in_specs=P('x', 'y'),
                            out_specs=P('x', 'y')))(arr)
    self.assertArraysEqual(out, np_inp * 24)

  def test_qr_decomposition_offload(self):
    if jtu.is_cloud_tpu():
      self.skipTest("Test fails on cloud TPU")

    shape = (3, 3)
    dtype = np.float32
    operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)

    @compute_on("device_host")
    @jax.jit
    def g(x):
      return lax.linalg.qr(x, full_matrices=True)

    @jax.jit
    def f(x):
      x, _ = lax.linalg.qr(x, full_matrices=True)
      x, _ = g(x)
      return x

    out = f(operand)  # doesn't crash
    lowered_text = f.lower(operand).as_text()
    self.assertIn('@lapack_sgeqrf', lowered_text)
    self.assertIn('@Qr', lowered_text)

    @jax.jit
    def h(x):
      x, _ = lax.linalg.qr(x, full_matrices=True)
      x, _ = lax.linalg.qr(x, full_matrices=True)
      return x

    expected_out = h(operand)

    self.assertArraysAllClose(out, expected_out, rtol=1e-3)

  def test_mem_kind_donation_pinned_host(self):
    mesh = jtu.create_mesh((2,), "x")
    s = NamedSharding(mesh, P(), memory_kind='pinned_host')
    s_dev = s.with_memory_kind('device')

    @compute_on('device_host')
    @functools.partial(jax.jit, out_shardings=(s, s_dev), donate_argnums=(0, 1))
    def f(inp1, inp2):
      return inp1 * 2, inp2 * 2

    np_inp = np.arange(16).reshape(8, 2)
    x = jax.device_put(np_inp, s)
    x_dev = jax.device_put(np_inp, s_dev)

    f(x, x_dev)

    lowered_text = f.lower(x, x_dev).as_text("hlo")
    self.assertIn("input_output_alias", lowered_text)
    self.assertDeleted(x)
    self.assertDeleted(x_dev)

  @parameterized.parameters("pinned_host", "device")
  def test_identity_mem_kind_donation(self, mem_kind):
    mesh = jtu.create_mesh((2,), "x")
    s = NamedSharding(mesh, P(), memory_kind=mem_kind)

    @functools.partial(jax.jit, out_shardings=s, donate_argnums=0)
    def f(inp):
      return inp

    np_inp = np.arange(16).reshape(8, 2)
    x = jax.device_put(np_inp, s)

    f(x)

    lowered_text = f.lower(x).as_text("hlo")
    self.assertIn("input_output_alias", lowered_text)
    self.assertDeleted(x)

  def test_compute_offload_with_donation(self):
    sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    p_sharding = jax.sharding.SingleDeviceSharding(
        jax.devices()[0], memory_kind="pinned_host"
    )

    @compute_on("device_host")
    @jax.jit
    def host_fn(x_in, y_in):
      return x_in * x_in, y_in + y_in

    def test_fn(x_in, y_in):
      x_out, y_out = host_fn(x_in, y_in)
      return x_out, y_out

    x = jnp.arange(0, 1024, dtype=jnp.float32)
    y = jnp.arange(0, 1024, dtype=jnp.float32)
    y = jax.device_put(y, p_sharding)

    x1 = jnp.arange(0, 1024, dtype=jnp.float32)
    y1 = jnp.arange(0, 1024, dtype=jnp.float32)

    jit_fn = jax.jit(
        test_fn,
        in_shardings=(sharding, p_sharding),
        out_shardings=(sharding, p_sharding),
        donate_argnums=(0, 1),
    )
    x_out, y_out = jit_fn(x, y)
    self.assertArraysEqual(x_out, x1 * x1)
    self.assertArraysEqual(y_out, y1 + y1)

  def test_compute_offload_with_linear_layout(self):
    # TODO(apaszke): Remove after 12 weeks have passed.
    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
      self.skipTest("Requires libtpu built after 2024-12-19")
    sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    p_sharding = jax.sharding.SingleDeviceSharding(
        jax.devices()[0], memory_kind="pinned_host"
    )

    @compute_on("device_host")
    @jax.jit
    def host_fn(x_in, y_in):
      return x_in * x_in, y_in + y_in

    def test_fn(x_in, y_in):
      x_out, y_out = host_fn(x_in, y_in)
      return x_out, y_out

    x = jnp.arange(0, 1024, dtype=jnp.float32)
    x = jnp.reshape(x, (16, 64))
    y = jnp.arange(0, 1024, dtype=jnp.float32)
    y = jnp.reshape(y, (16, 64))
    custom_dll = DLL(major_to_minor=(0, 1), _tiling=((8, 128),))
    custom_dll_linear = DLL(major_to_minor=(0, 1), _tiling=((1,),))
    x = jax.device_put(x, Layout(custom_dll, sharding))
    y = jax.device_put(y, Layout(custom_dll_linear, p_sharding))

    x1 = jnp.arange(0, 1024, dtype=jnp.float32)
    x1 = jnp.reshape(x1, (16, 64))
    y1 = jnp.arange(0, 1024, dtype=jnp.float32)
    y1 = jnp.reshape(y1, (16, 64))

    jit_fn = jax.jit(
        test_fn,
        out_shardings=(
            Layout(custom_dll, sharding),
            Layout(custom_dll_linear, p_sharding),
        ),
    )
    x_out, y_out = jit_fn(x, y)
    self.assertArraysEqual(x_out, x1 * x1)
    self.assertArraysEqual(y_out, y1 + y1)

  def test_compute_offload_mesh_with_linear_layout(self):
    mesh = jtu.create_mesh((2, 2), ("x", "y"))
    sharding = NamedSharding(mesh, P("x", "y"))
    p_sharding = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host")

    @compute_on("device_host")
    @jax.jit
    def host_fn(x_in, y_in):
      return x_in * x_in, y_in + y_in

    def test_fn(x_in, y_in):
      x_out, y_out = host_fn(x_in, y_in)
      return x_out, y_out

    x = jnp.arange(0, 2048, dtype=jnp.float32)
    x = jnp.reshape(x, (32, 64))
    y = jnp.arange(0, 2048, dtype=jnp.float32)
    y = jnp.reshape(y, (32, 64))
    custom_dll = DLL(major_to_minor=(0, 1), _tiling=((8, 128),))
    custom_dll_linear = DLL(major_to_minor=(0, 1), _tiling=((1,),))
    x = jax.device_put(x, Layout(custom_dll, sharding))
    y = jax.device_put(y, Layout(custom_dll_linear, p_sharding))

    x1 = jnp.arange(0, 2048, dtype=jnp.float32)
    x1 = jnp.reshape(x1, (32, 64))
    y1 = jnp.arange(0, 2048, dtype=jnp.float32)
    y1 = jnp.reshape(y1, (32, 64))

    jit_fn = jax.jit(
        test_fn,
        out_shardings=(
            Layout(custom_dll, sharding),
            Layout(custom_dll_linear, p_sharding),
        ),
    )
    x_out, y_out = jit_fn(x, y)
    self.assertArraysEqual(x_out, x1 * x1)
    self.assertArraysEqual(y_out, y1 + y1)

  def test_compute_on_cache_miss(self):
    @jax.jit
    def f(x):
      return x * 2

    inp = jnp.arange(10)
    with jtu.count_jit_tracing_cache_miss() as count:
      with compute_on('device_host'):
        f(inp)

      with compute_on('device'):
        f(inp)

    # 2 for `f` and 2 for `mul` (compute type changes for `mul`)
    self.assertEqual(count(), 4)

  def test_offload_take_host(self):
    # TODO(apaszke): Remove after 12 weeks have passed.
    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
      self.skipTest("Requires libtpu built after 2024-12-19")
    @compute_on('device_host')
    @jax.jit
    def peer_forward(x, experts, indices, scores):
      w = jnp.take(experts, indices.astype(int), axis=0)
      w_gate, w_down, w_up = w[..., 0], w[..., 1], w[..., 2]
      g = jnp.einsum('btd, bthkd->bthk', x, w_gate)
      x = jnp.einsum('btd, bthkd->bthk', x, w_down)
      x = x * jax.nn.gelu(g) * scores
      return jnp.einsum('bthk, bthkd->btd', x, w_up)

    x = jnp.ones((16, 4, 32))
    experts = jnp.ones((128, 32, 3))
    indices = jnp.ones((16, 4, 4, 2), dtype=jnp.int32)
    scores = jnp.ones((16, 4, 4, 2))
    jax.jit(peer_forward)(x, experts, indices, scores)  # doesn't crash


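# compute_on('gpu_stream:N') annotations assign sub-computations to explicit
# GPU streams; per the skip condition below, this is only supported on the
# GPU backend.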
class StreamAnnotationTest(jtu.JaxTestCase):

  def test_stream_annotation_inside_shmap(self):
    if not jtu.test_device_matches(["gpu"]):
      self.skipTest("Stream annotation is only supported on GPU.")
    mesh = jtu.create_mesh((2,), ('x',))
    s = NamedSharding(mesh, P('x'))
    np_inp = np.ones((8,))
    arr1 = jax.device_put(np_inp, s)
    arr2 = jax.device_put(np_inp, s)

    @compute_on('gpu_stream:1')
    @jax.jit
    def g(x, y):
      return x * y

    @compute_on('gpu_stream:2')
    @jax.jit
    def h(x, y):
      return x * y

    def f(x, y):
      z = g(x, y)
      w = h(3 * x, 2 * y)
      return z + w

    out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')),
                            out_specs=P('x')))(arr1, arr2)
    self.assertArraysEqual(out, arr1 * 7)


class ActivationOffloadingTest(jtu.JaxTestCase):

  def setUp(self):
    if not jtu.test_device_matches(["tpu", "gpu"]):
      self.skipTest("Memories do not work on CPU backend.")
    super().setUp()

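  # The policy below offloads every residual to pinned host memory: the
  # forward jaxpr should move the three sin residuals to 'pinned_host' and
  # the backward jaxpr should bring them back to 'device'.  The HLO checks
  # below treat 'S(5)' as the marker the compiler uses for buffers placed in
  # host memory space.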
  def test_remat_jaxpr_offloadable(self):
    mesh = jtu.create_mesh((2,), ("x",))
    inp = jax.device_put(np.arange(16.), NamedSharding(mesh, P("x")))

    def policy(prim, *avals, **params):
      return Offloadable(src="device", dst="pinned_host")

    @functools.partial(remat, policy=policy)
    def f(x):
      x = jnp.sin(x)
      x = jnp.sin(x)
      x = jnp.sin(x)
      return jnp.sum(x)

    fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, inp)

    self.assertLen(fwd_jaxpr.out_avals, 4)  # 1 output, 3 offloaded residuals
    fwd_mem_kind_count = str(fwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='pinned_host')")
    self.assertEqual(fwd_mem_kind_count, 3)

    self.assertLen(bwd_jaxpr.in_avals, 4)  # 3 offloaded residuals, 1 input
    bwd_mem_kind_count = str(bwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='device')")
    self.assertEqual(bwd_mem_kind_count, 3)

    # Execution test.
    f = jax.jit(jax.grad(f))
    f(inp)  # doesn't crash

    compiled_f = f.lower(inp).compile()

    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      self.assertIn('S(5)', compiled_text)
      self.assertRegex(compiled_text, r"copy-start.*S\(5\)")
      self.assertRegex(compiled_text, r"copy-done.*S\(5\)")

    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
      if jtu.pjrt_c_api_version_at_least(0, 43):
        self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

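  # save_and_offload_only_these_names keeps residuals named "y" on the device
  # and offloads those named "z" and "w" to pinned host memory.  Overlapping
  # name sets are invalid, which the assertRaisesRegex below checks first.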
  def test_remat_scan_jaxpr_offloadable(self):
    mesh = jtu.create_mesh((2,), ("x",))
    shape = (256, 128)
    np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    s = NamedSharding(mesh, P("x"))
    inp = jax.device_put(np_inp, s)

    with self.assertRaisesRegex(
        ValueError, "The names should be exclusive and should not intersect"):
      jax.checkpoint_policies.save_and_offload_only_these_names(
          names_which_can_be_saved=["y"], names_which_can_be_offloaded=["y", "w"],
          offload_src="device", offload_dst="pinned_host")

    policy = jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z", "w"],
        offload_src='device', offload_dst='pinned_host')

    @functools.partial(remat, policy=policy)
    def f(x):
      def g(ys, _):
        y, _ = ys
        y = checkpoint_name(jnp.sin(y), "y")
        z = checkpoint_name(jnp.sin(y), "z")
        z = jax.lax.with_sharding_constraint(z, s)
        w = checkpoint_name(jnp.sin(z), "w")
        return (w, jnp.sum(w)), None
      _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
      return scan_out

    fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, inp)

    self.assertLen(fwd_jaxpr.out_avals, 5)  # 2 outputs, 3 residuals
    fwd_mem_kind_count = str(fwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='pinned_host')")
    self.assertEqual(fwd_mem_kind_count, 2)

    self.assertLen(bwd_jaxpr.in_avals, 5)  # 3 residuals, 2 inputs
    bwd_mem_kind_count = str(bwd_jaxpr).count(
        "TransferToMemoryKind(memory_kind='device')")
    self.assertEqual(bwd_mem_kind_count, 2)

    f = jax.jit(jax.grad(f))
    f(inp)  # doesn't crash

    compiled_f = f.lower(inp).compile()

    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      self.assertIn('S(5)', compiled_text)
      self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
      self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")

    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
      self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

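  # Same policy as the previous test, but the transpose inside the scan body
  # changes the residual's layout, so the host transfers are expected to show
  # up as dynamic-(update-)slice ops on S(5) rather than plain
  # copy-start/copy-done pairs.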
  def test_remat_scan_layout_change_offloadable(self):
    mesh = jtu.create_mesh((2,), ("x",))
    shape = (256, 128)
    np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    s = NamedSharding(mesh, P("x"))
    inp = jax.device_put(np_inp, s)

    policy = jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z", "w"],
        offload_src='device', offload_dst='pinned_host')

    @functools.partial(remat, policy=policy)
    def f(x):
      def g(ys, _):
        y, _ = ys
        y = checkpoint_name(jnp.sin(y), "y")
        z = checkpoint_name(jnp.sin(y), "z")
        z = jax.lax.with_sharding_constraint(z, s)
        z = z.T
        w = checkpoint_name(jnp.sin(z), "w")
        return (w.T, jnp.sum(w)), None
      _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
      return scan_out

    f = jax.jit(jax.grad(f))
    f(inp)  # doesn't crash

    compiled_f = f.lower(inp).compile()

    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      self.assertIn('S(5)', compiled_text)
      self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
      self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
      self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)")
      self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)")
      self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)")
      self.assertIn("dynamic-slice-done", compiled_text)

    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
      self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

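  # offload_dot_with_no_batch_dims should offload the residuals of the einsum
  # contractions below (dots with no batch dimensions) to pinned host memory,
  # so the compiled HLO is expected to contain copies to and from S(5).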
  def test_remat_checkpoint_dots_with_no_batch_dims(self):
    policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
        "device", "pinned_host")

    @functools.partial(new_checkpoint, policy=policy)
    def f(x):
      x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
      x = jnp.sin(x)
      x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
      x = jnp.sin(x)
      x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
      x = jnp.sin(x)
      x = jnp.sum(x)
      return x

    inp = jnp.ones((2, 2))
    f = jax.jit(jax.grad(f))
    f(inp)  # doesn't crash

    compiled_f = f.lower(inp).compile()

    compiled_text = compiled_f.as_text()
    if compiled_text is not None:
      self.assertIn('S(5)', compiled_text)
      self.assertRegex(compiled_text, r"copy-start.*S\(5\)")
      self.assertRegex(compiled_text, r"copy-done.*S\(5\)")

    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
      self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)

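  # The policy offloads only primitives with multiple results; approx_max_k
  # returns both values and indices, so this exercises residual offloading
  # for multi-output primitives under remat.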
  def test_primitive_with_multiple_outputs(self):
    # Test for https://github.com/jax-ml/jax/issues/25841
    shape = (128,)
    inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

    def policy(prim, *args, **kwargs):
      del args, kwargs
      if prim.multiple_results:
        return Offloadable("device", "pinned_host")
      return Recompute

    @functools.partial(remat, policy=policy)
    def test_fn(x):
      # Need any primitive with multiple outputs and a non-trivial grad.
      x1, _ = jax.lax.approx_max_k(x, k=2)
      return jnp.sum(x1)

    fn = jax.grad(test_fn)
    jax.jit(fn)(inp)  # doesn't crash


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())