Move memories_test.py to JAX

PiperOrigin-RevId: 564551723
This commit is contained in:
Yash Katariya 2023-09-11 17:41:18 -07:00 committed by jax authors
parent 3e06dc8b77
commit 76a5dc3cac
2 changed files with 989 additions and 0 deletions

View File

@ -185,6 +185,17 @@ jax_test(
],
)
# Test target for memories_test.py: the "tpu_se" config is disabled and the
# TPU run is split into 5 shards.
jax_test(
    name = "memories_test",
    srcs = ["memories_test.py"],
    disable_configs = [
        "tpu_se",
    ],
    shard_count = {
        "tpu": 5,
    },
)
jax_test(
name = "pjit_test",
srcs = ["pjit_test.py"],

978
tests/memories_test.py Normal file
View File

@ -0,0 +1,978 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
import warnings
from absl.testing import absltest
from absl.testing import parameterized
import unittest
import jax
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax._src.sharding_impls import (NamedSharding, PositionalSharding,
SingleDeviceSharding, GSPMDSharding,
TransferToMemoryKind,
common_devices_indices_map)
import numpy as np
from jax import config
# Hook JAX's config into absl flag parsing at import time (going by the
# method name; see jax.config for the exact semantics).
config.parse_flags_with_absl()
def get_memory_kinds_from_executable(f, args):
  """Compile `f` for `args` and return the output memory kinds it reports.

  Args:
    f: a callable exposing `.lower(...)` (e.g. a jitted function).
    args: sequence of positional arguments `f` is lowered with.

  Returns:
    The first entry of the runtime executable's
    `get_output_memory_kinds()` result; callers index it per output.
  """
  lowered = f.lower(*args)
  runtime_exe = lowered.compile().runtime_executable()
  reported_kinds = runtime_exe.get_output_memory_kinds()
  return reported_kinds[0]
def _create_inputs(shape, pspec, mem_kind=None):
  """Build a (4, 2) mesh with axes "x"/"y" and an input sharded over it.

  Args:
    shape: shape of the array to create.
    pspec: PartitionSpec used to build the NamedSharding.
    mem_kind: optional memory kind forwarded to the NamedSharding.

  Returns:
    A `(mesh, sharding, numpy_array, device_array)` tuple. The numpy array
    holds `arange(prod(shape))` reshaped to `shape`; the device array is the
    same data placed with `sharding` via `jax.device_put`.
  """
  mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
  sharding = NamedSharding(mesh, pspec, memory_kind=mem_kind)
  host_array = np.arange(math.prod(shape)).reshape(shape)
  device_array = jax.device_put(host_array, sharding)
  return mesh, sharding, host_array, device_array
# Tests TODO
# * wsc with memory_kinds
# * shard_map
# * AOT
# * autodiff tests (jtu.check_grads)
# * scan tests
# * jaxpr checks for primitive running on different mem kinds
# * nested jit
class MemoriesTest(jtu.JaxTestCase):
def setUp(self):
  """Skip the test on CPU and GPU backends, then run the base setUp."""
  unsupported_backends = ("cpu", "gpu")
  if jtu.device_under_test() in unsupported_backends:
    self.skipTest("Memories do not work on CPU and GPU backends yet.")
  super().setUp()
def _check_mem_kind(self, executable_kind, out_sharding, expected_kind):
  """Assert that the executable, the sharding and the expectation agree.

  Args:
    executable_kind: memory kind reported by the compiled executable.
    out_sharding: object whose `memory_kind` attribute is checked.
    expected_kind: memory kind the test expects.
  """
  sharding_kind = out_sharding.memory_kind
  # Same three pairwise comparisons, in the same order, as separate asserts.
  pairs = (
      (executable_kind, sharding_kind),
      (sharding_kind, expected_kind),
      (executable_kind, expected_kind),
  )
  for left, right in pairs:
    self.assertEqual(left, right)
def _check_device_put_addressable_shards(
    self, out, inp, expected_sharding, expected_mem_kind, index=True):
  """Check an array's value, sharding and memory kind, shard by shard.

  Args:
    out: the array under test.
    inp: the reference value `out` is compared against.
    expected_sharding: sharding `out` must carry.
    expected_mem_kind: memory kind expected on `out` and on every shard.
    index: when true, each shard is compared with `inp[shard.index]`;
      otherwise each shard is compared with `inp` as a whole.
  """
  self.assertArraysEqual(out, inp)
  out_sharding = out.sharding
  self.assertEqual(out_sharding, expected_sharding)
  self.assertEqual(out.sharding.memory_kind, expected_mem_kind)
  for shard in out.addressable_shards:
    expected_piece = inp[shard.index] if index else inp
    self.assertArraysEqual(shard.data, expected_piece)
    self.assertEqual(shard.data.sharding.memory_kind, expected_mem_kind)
@parameterized.named_parameters(
    ("named_sharding", "named_sharding"),
    ("positional_sharding", "positional_sharding"),
    ("single_device_sharding", "single_device_sharding"),
    ("gspmd_sharding", "gspmd_sharding"),
)
def test_canonicalize_memory_kind(self, name):
  """A sharding built without `memory_kind` reports "tpu_hbm"."""
  if name == "named_sharding":
    mesh = jtu.create_global_mesh((1,), "x")
    sharding = NamedSharding(mesh, P("x"))
  elif name == "positional_sharding":
    sharding = PositionalSharding(jax.devices())
  elif name == "single_device_sharding":
    sharding = SingleDeviceSharding(jax.devices()[0])
  else:
    assert name == "gspmd_sharding"
    sharding = GSPMDSharding.get_replicated(jax.devices())
  self.assertEqual(sharding.memory_kind, "tpu_hbm")
@parameterized.named_parameters(
    ("named_sharding", "named_sharding"),
    ("positional_sharding", "positional_sharding"),
    ("single_device_sharding", "single_device_sharding"),
    ("gspmd_sharding", "gspmd_sharding"),
)
def test_wrong_memory_kind(self, name):
  """Building a sharding with an unknown memory kind raises ValueError."""
  not_found = "Could not find memory addressable by device TPU.*"
  if name == "named_sharding":
    # Build the mesh outside the assertRaises block so that only the
    # NamedSharding constructor can satisfy the expected error.
    mesh = jtu.create_global_mesh((8,), ("x",))
    with self.assertRaisesRegex(ValueError, not_found):
      NamedSharding(mesh, P("x"), memory_kind="hbm")
  elif name == "positional_sharding":
    with self.assertRaisesRegex(ValueError, not_found):
      PositionalSharding(jax.devices(), memory_kind="gpu_hbm")
  elif name == "single_device_sharding":
    # This variant also checks that the message lists the valid kinds.
    with self.assertRaisesRegex(
        ValueError,
        "Could not find memory addressable by device TPU.*Device TPU.*"
        " can address the following memory kinds: tpu_hbm, unpinned_host.*",
    ):
      SingleDeviceSharding(jax.devices()[0], memory_kind="host")
  else:
    assert name == "gspmd_sharding"
    with self.assertRaisesRegex(ValueError, not_found):
      GSPMDSharding.get_replicated(jax.devices(), memory_kind="my_host")
@parameterized.named_parameters(
    ("named_sharding", "named_sharding"),
    ("positional_sharding", "positional_sharding"),
    ("single_device_sharding", "single_device_sharding"),
    ("gspmd_sharding", "gspmd_sharding"),
)
def test_correct_tpu_memory_kind(self, name):
  """Shardings accept "tpu_hbm" / "unpinned_host" without raising."""
  # No assertions: the test passes as long as construction succeeds.
  if name == "named_sharding":
    NamedSharding(
        jtu.create_global_mesh((8,), ("x",)), P("x"), memory_kind="tpu_hbm"
    )
  elif name == "positional_sharding":
    PositionalSharding(jax.devices(), memory_kind="tpu_hbm")
  elif name == "single_device_sharding":
    first_device = jax.devices()[0]
    SingleDeviceSharding(first_device, memory_kind="unpinned_host")
  else:
    assert name == "gspmd_sharding"
    GSPMDSharding.get_replicated(jax.devices(), memory_kind="unpinned_host")
@parameterized.named_parameters(
    ("named_sharding", "named_sharding"),
    ("positional_sharding", "positional_sharding"),
    ("single_device_sharding", "single_device_sharding"),
    ("gspmd_sharding", "gspmd_sharding"),
)
def test_sharding_eq(self, name):
  """A default sharding equals one with an explicit "tpu_hbm" kind."""
  if name == "named_sharding":
    mesh = jtu.create_global_mesh((8,), ("x",))
    s1 = NamedSharding(mesh, P("x"))
    s2 = NamedSharding(mesh, P("x"), memory_kind="tpu_hbm")
  elif name == "positional_sharding":
    s1 = PositionalSharding(jax.devices())
    s2 = PositionalSharding(jax.devices(), memory_kind="tpu_hbm")
  elif name == "single_device_sharding":
    s1 = SingleDeviceSharding(jax.devices()[0])
    s2 = SingleDeviceSharding(jax.devices()[0], memory_kind="tpu_hbm")
  else:
    # Fail loudly on an unknown parameter instead of passing with no
    # assertion executed, matching the sibling parameterized tests.
    assert name == "gspmd_sharding"
    s1 = GSPMDSharding.get_replicated(jax.devices())
    s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind="tpu_hbm")
  self.assertEqual(s1, s2)
def test_sharding_equivalent(self):
  """NamedSharding / GSPMDSharding equivalence ignores which side spells
  out "tpu_hbm" explicitly."""
  ndim = 2
  mesh = jtu.create_global_mesh((8,), ("x",))
  flat_devices = tuple(mesh.devices.flat)

  # Default kind on the NamedSharding, explicit kind on the GSPMDSharding.
  named_default = NamedSharding(mesh, P("x"))
  gspmd_explicit = GSPMDSharding(
      flat_devices,
      named_default._to_xla_hlo_sharding(ndim),
      memory_kind="tpu_hbm",
  )
  self.assertTrue(named_default.is_equivalent_to(gspmd_explicit, ndim))

  # And the other way round.
  named_explicit = NamedSharding(mesh, P("x"), memory_kind="tpu_hbm")
  gspmd_default = GSPMDSharding(
      flat_devices, named_explicit._to_xla_hlo_sharding(ndim)
  )
  self.assertTrue(named_explicit.is_equivalent_to(gspmd_default, ndim))
def test_default_memory_kind(self):
  """The first device's default memory has kind "tpu_hbm"."""
  default_memory = jax.devices()[0].default_memory()
  self.assertEqual(default_memory.kind, "tpu_hbm")
def test_jit_memory_transfer_to_host_middle(self):
  """An intermediate moved to host and back yields outputs on "tpu_hbm"."""
  _, s, np_inp, inp = _create_inputs(
      (8, 2), P("x", "y"), mem_kind="tpu_hbm"
  )

  @jax.jit
  def f(x):
    doubled = x * 2
    on_host = jax.device_put(doubled, s.with_memory_kind("unpinned_host"))
    tripled = on_host * 3
    back_on_hbm = jax.device_put(tripled, s.with_memory_kind("tpu_hbm"))
    return back_on_hbm * 4, back_on_hbm

  out1, out2 = f(inp)
  executable_mk = get_memory_kinds_from_executable(f, [inp])

  # Values: 2 * 3 * 4 == 24 for the first output, 2 * 3 == 6 for the second.
  self.assertArraysEqual(out1, np_inp * 24)
  self.assertArraysEqual(out2, np_inp * 6)
  for out in (out1, out2):
    self.assertEqual(out.sharding, s)
  self._check_mem_kind(executable_mk[0], out1.sharding, "tpu_hbm")
  self._check_mem_kind(executable_mk[1], out2.sharding, "tpu_hbm")
def test_addressable_shards_mem_kind(self):
  """Every addressable shard of a host-memory output is on host too."""
  _, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"))

  @jax.jit
  def f(x):
    x = jax.device_put(x, s.with_memory_kind("unpinned_host"))
    return x * 2

  out = f(inp)
  executable_mk = get_memory_kinds_from_executable(f, [inp])
  expected_out = np_inp * 2
  self.assertArraysEqual(out, expected_out)
  self.assertEqual(out.sharding, s.with_memory_kind("unpinned_host"))
  self._check_mem_kind(executable_mk[0], out.sharding, "unpinned_host")
  # The loop variable is deliberately not called `s`: `f` closes over the
  # sharding `s`, and rebinding it here would break any later retrace.
  for shard in out.addressable_shards:
    self.assertArraysEqual(shard.data, expected_out[shard.index])
    self._check_mem_kind(
        executable_mk[0], shard.data.sharding, "unpinned_host"
    )
def test_jit_host_multi_outputs(self):
  """One jit can return one output on host and another on HBM."""
  _, s, np_inp, inp = _create_inputs((8, 2), P("x"))

  @jax.jit
  def f(x, y):
    sin_x = jnp.sin(x)
    cos_y = jnp.cos(y)
    sin_x = jax.device_put(sin_x, s.with_memory_kind("unpinned_host"))
    cos_y = jax.device_put(cos_y, s.with_memory_kind("tpu_hbm"))
    return sin_x, cos_y

  host_out, hbm_out = f(inp, inp)
  self.assertArraysAllClose(host_out, np.sin(np_inp))
  self.assertArraysAllClose(hbm_out, np.cos(np_inp))
  self.assertEqual(host_out.sharding, s.with_memory_kind("unpinned_host"))
  self.assertEqual(hbm_out.sharding, s.with_memory_kind("tpu_hbm"))
def test_jit_explicit_tpu_hbm(self):
  """An input with explicit "tpu_hbm" yields a "tpu_hbm" output."""
  _, sharding, np_inp, inp = _create_inputs(
      (8, 2), P("x"), mem_kind="tpu_hbm"
  )

  @jax.jit
  def f(x):
    return x * 2

  result = f(inp)
  executable_mk = get_memory_kinds_from_executable(f, [inp])
  self.assertEqual(result.sharding, sharding)
  self.assertArraysEqual(result, np_inp * 2)
  self._check_mem_kind(executable_mk[0], result.sharding, "tpu_hbm")
def test_same_constant_value_on_different_memories(self):
  """The constant 2 is used as a multiplier on both HBM and host values."""
  _, s, np_inp, inp = _create_inputs(
      (8, 2), P("x", "y"), mem_kind="tpu_hbm"
  )

  @jax.jit
  def f(x):
    on_hbm = x * 2
    on_host = jax.device_put(on_hbm, s.with_memory_kind("unpinned_host"))
    host_result = on_host * 2
    back_on_hbm = jax.device_put(host_result, s.with_memory_kind("tpu_hbm"))
    return back_on_hbm * 2, host_result

  hbm_out, host_out = f(inp)
  executable_mk = get_memory_kinds_from_executable(f, [inp])
  self.assertArraysEqual(hbm_out, np_inp * 8)
  self.assertArraysEqual(host_out, np_inp * 4)
  self._check_mem_kind(executable_mk[0], hbm_out.sharding, "tpu_hbm")
  self._check_mem_kind(executable_mk[1], host_out.sharding, "unpinned_host")
def test_jit_out_shardings(self):
  """A host-memory `out_shardings` applies to every returned output."""
  _, s, _, inp = _create_inputs((8, 2), P("x", "y"))
  host_s = s.with_memory_kind("unpinned_host")
  jit_to_host = functools.partial(jax.jit, out_shardings=host_s)

  def _check(fun):
    kinds = get_memory_kinds_from_executable(fun, [inp])
    results = fun(inp)
    for result, kind in zip(results, kinds):
      self._check_mem_kind(kind, result.sharding, "unpinned_host")
      self.assertEqual(result.sharding, s.with_memory_kind("unpinned_host"))

  # Two computed outputs.
  @jit_to_host
  def f(x):
    return x * 2, x * 2
  _check(f)

  # One forwarded input plus one computed output.
  @jit_to_host
  def h(x):
    return x, x * 3
  _check(h)

  # The input forwarded twice.
  @jit_to_host
  def i(x):
    return x, x
  _check(i)
def test_jit_out_shardings_single_output(self):
  """A scalar output lands on host via `out_shardings` or `device_put`."""
  mesh, _, _, inp = _create_inputs((8, 2), P("x", "y"))
  out_s = NamedSharding(mesh, P(), memory_kind="unpinned_host")

  # Variant 1: placement requested through `out_shardings`.
  @functools.partial(jax.jit, out_shardings=out_s)
  def g(x):
    return jnp.sum(x * 2)

  g_out = g(inp)
  self.assertEqual(g_out.sharding, out_s)
  g_kinds = get_memory_kinds_from_executable(g, [inp])
  self._check_mem_kind(g_kinds[0], g_out.sharding, "unpinned_host")

  # Variant 2: placement requested through an in-jit `device_put`.
  @jax.jit
  def h(x):
    x = jnp.sum(x * 2)
    out = jax.device_put(x, out_s)
    return out

  h_out = h(inp)
  self.assertEqual(h_out.sharding, out_s)
  h_kinds = get_memory_kinds_from_executable(h, [inp])
  self._check_mem_kind(h_kinds[0], h_out.sharding, "unpinned_host")
def test_jit_device_put_host_output(self):
  """Outputs derived from a host `device_put` inside jit stay on host."""
  _, s, _, inp = _create_inputs((8, 2), P("x", "y"))

  def _check(fun):
    kinds = get_memory_kinds_from_executable(fun, [inp])
    results = fun(inp)
    for result, kind in zip(results, kinds):
      self._check_mem_kind(kind, result.sharding, "unpinned_host")
      self.assertEqual(result.sharding, s.with_memory_kind("unpinned_host"))

  # The transferred value returned twice.
  @jax.jit
  def f(x):
    x = x * 2
    out = jax.device_put(x, s.with_memory_kind("unpinned_host"))
    return out, out
  _check(f)

  # The transferred value plus a value computed from it.
  @jax.jit
  def h(x):
    x = x * 2
    out = jax.device_put(x, s.with_memory_kind("unpinned_host"))
    return out, out * 3
  _check(h)

  # Two values computed from the transferred value.
  @jax.jit
  def i(x):
    x = x * 2
    out = jax.device_put(x, s.with_memory_kind("unpinned_host"))
    return out * 2, out * 2
  _check(i)
def test_jit_in_shardings(self):
  """`in_shardings` memory kind must match the argument's memory kind."""
  _, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
  mismatch_regex = (
      "Memory kinds passed to jax.jit does not match memory kind on the"
      " respective arg. Got pjit memory kind: unpinned_host, arg memory kind:"
      " tpu_hbm for arg shape.*"
  )

  @functools.partial(
      jax.jit, in_shardings=s.with_memory_kind("unpinned_host")
  )
  def f(x):
    return x * 2

  # An uncommitted input is rejected as well.
  with self.assertRaisesRegex(ValueError, mismatch_regex):
    f(jnp.arange(16).reshape(8, 2))

  # A committed input is rejected.
  with self.assertRaisesRegex(ValueError, mismatch_regex):
    f(inp)

  @functools.partial(jax.jit, in_shardings=s.with_memory_kind("tpu_hbm"))
  def g(x):
    return x * 2

  # A matching memory kind is accepted.
  result = g(inp)
  executable_kind = get_memory_kinds_from_executable(g, [inp])
  self.assertArraysEqual(result, np_inp * 2)
  self._check_mem_kind(executable_kind[0], result.sharding, "tpu_hbm")
def test_jit_in_out_shardings(self):
  """`in_shardings` and `out_shardings` may use different memory kinds."""
  mesh, s, np_inp, inp = _create_inputs(
      (8, 2), P("x", "y"), mem_kind="tpu_hbm"
  )
  out_s = NamedSharding(mesh, P(), memory_kind="tpu_hbm")
  expected_sum = np.sum(np_inp)

  # HBM in, HBM out.
  @functools.partial(jax.jit, in_shardings=s, out_shardings=out_s)
  def f(x):
    return jnp.sum(x)

  f_out = f(inp)
  f_kinds = get_memory_kinds_from_executable(f, [inp])
  self.assertArraysEqual(f_out, expected_sum)
  self._check_mem_kind(f_kinds[0], f_out.sharding, "tpu_hbm")

  # HBM in, host out.
  @functools.partial(
      jax.jit,
      in_shardings=s,
      out_shardings=out_s.with_memory_kind("unpinned_host"),
  )
  def g(x):
    return jnp.sum(x)

  g_out = g(inp)
  g_kinds = get_memory_kinds_from_executable(g, [inp])
  self.assertArraysEqual(g_out, expected_sum)
  self._check_mem_kind(g_kinds[0], g_out.sharding, "unpinned_host")
def test_device_put_different_devices(self):
  """A single-device `device_put` on a mesh-sharded input is rejected."""
  _, _, _, inp = _create_inputs((8, 2), P("x", "y"))

  @jax.jit
  def f(x):
    single_device_host = SingleDeviceSharding(
        jax.devices()[0], memory_kind="unpinned_host"
    )
    return jax.device_put(x, single_device_host)

  expected_error = "Received incompatible devices for jitted computation"
  with self.assertRaisesRegex(ValueError, expected_error):
    f(inp)
def test_jit_multiple_transfers(self):
  """A pytree `device_put` moves several values to host in one call."""
  mesh, _, np_inp, inp = _create_inputs((8, 2), P(None, "y"))
  s2 = NamedSharding(mesh, P("x"))
  inp2 = jax.device_put(np_inp, s2)

  @jax.jit
  def f(x, y):
    total = x + y
    host_total, host_x = jax.device_put(
        (total, x), s2.with_memory_kind("unpinned_host")
    )
    return host_total * host_x, y * 2

  host_out, hbm_out = f(inp, inp2)
  executable_mem = get_memory_kinds_from_executable(f, [inp, inp2])
  self.assertArraysEqual(host_out, (np_inp + np_inp) * np_inp)
  self.assertArraysEqual(hbm_out, np_inp * 2)
  self._check_mem_kind(executable_mem[0], host_out.sharding, "unpinned_host")
  self._check_mem_kind(executable_mem[1], hbm_out.sharding, "tpu_hbm")
def test_jit_single_device_multi_output_host_mem(self):
  """Two outputs computed from a single-device host value are on host."""
  inp = jnp.arange(16).reshape(8, 2)

  @jax.jit
  def f(x):
    host_sharding = SingleDeviceSharding(
        jax.devices()[0], memory_kind="unpinned_host"
    )
    x = jax.device_put(x, host_sharding)
    return x * 2, x * 3

  doubled, tripled = f(inp)
  executable_mem = get_memory_kinds_from_executable(f, [inp])
  self.assertArraysEqual(doubled, inp * 2)
  self.assertArraysEqual(tripled, inp * 3)
  self._check_mem_kind(executable_mem[0], doubled.sharding, "unpinned_host")
  self._check_mem_kind(executable_mem[1], tripled.sharding, "unpinned_host")
def test_jit_reshard(self):
  """Resharding and moving to host memory can happen in the same jit."""
  mesh, _, np_inp, inp = _create_inputs((8, 2), P(None, "y"))
  out_s = NamedSharding(mesh, P(("x", "y")), memory_kind="unpinned_host")

  def _check(fun, inp):
    result = fun(inp)
    self.assertArraysEqual(result, np_inp * 2)
    self.assertEqual(result.sharding, out_s)
    kinds = get_memory_kinds_from_executable(fun, [inp])
    self._check_mem_kind(kinds[0], result.sharding, "unpinned_host")

  # Via `out_shardings`.
  @functools.partial(jax.jit, out_shardings=out_s)
  def f(x):
    return x * 2
  _check(f, inp)

  # Via an in-jit `device_put`.
  @jax.jit
  def g(x):
    y = jax.device_put(x, out_s)
    return y * 2
  _check(g, inp)
def test_jit_cpp_cache_hit(self):
  """Default and explicit "tpu_hbm" inputs cause one C++ cache miss."""
  mesh, _, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
  explicit_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="tpu_hbm")
  inp2 = jax.device_put(np_inp, explicit_hbm)

  f = jax.jit(lambda x: x @ x.T)
  with jtu.count_pjit_cpp_cache_miss() as count:
    out = f(inp)
    out2 = f(inp2)
  self.assertEqual(count[0], 1)

  expected = np_inp @ np_inp.T
  self.assertArraysEqual(out, expected)
  self.assertArraysEqual(out2, expected)
def test_jit_compilation_cache_hit(self):
  """Equivalent Named/GSPMD inputs miss the C++ cache but compile once."""
  mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
  gspmd_s = GSPMDSharding(
      tuple(mesh.devices.flat),
      s._to_xla_hlo_sharding(inp.ndim),
      memory_kind="tpu_hbm",
  )
  inp2 = jax.device_put(np_inp, gspmd_s)

  f = jax.jit(lambda x: x @ x.T)
  with (
      jtu.count_pjit_cpp_cache_miss() as cpp_count,
      jtu.count_jit_and_pmap_compiles() as compile_count,
  ):
    f(inp)
    f(inp2)
  # Two C++ cache misses, but only one compilation.
  self.assertEqual(cpp_count[0], 2)
  self.assertEqual(compile_count[0], 1)
def test_jit_cpp_cache_output_hit(self):
  """Feeding a jit output back into the same jit hits the C++ cache."""
  _, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="tpu_hbm")

  @jax.jit
  def mul_two(x):
    return x * 2

  with jtu.count_pjit_cpp_cache_miss() as miss_count:
    first = mul_two(inp)
    mul_two(first)
  self.assertEqual(miss_count[0], 1)
def test_jit_cache_miss(self):
  """HBM and host inputs to the same jit each miss and each compile."""
  mesh, _, np_inp, inp = _create_inputs(
      (8, 2), P("x", "y"), mem_kind="tpu_hbm"
  )
  out_s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")

  @functools.partial(jax.jit, out_shardings=out_s_host)
  def mul_three(x):
    return x * 3

  with (
      jtu.count_pjit_cpp_cache_miss() as cpp_count,
      jtu.count_jit_and_pmap_compiles() as compile_count,
  ):
    out = mul_three(inp)   # input on tpu_hbm
    out2 = mul_three(out)  # input on unpinned_host
  self.assertEqual(cpp_count[0], 2)
  self.assertEqual(compile_count[0], 2)

  self.assertEqual(out.sharding, out_s_host)
  self.assertEqual(out2.sharding, out_s_host)
  self.assertArraysEqual(out, np_inp * 3)
  self.assertArraysEqual(out2, np_inp * 9)

  kinds_from_hbm = get_memory_kinds_from_executable(mul_three, [inp])
  self._check_mem_kind(kinds_from_hbm[0], out.sharding, "unpinned_host")
  kinds_from_host = get_memory_kinds_from_executable(mul_three, [out])
  self._check_mem_kind(kinds_from_host[0], out2.sharding, "unpinned_host")
def test_jit_host_input_from_another_jit_output(self):
  """A host-memory jit output can be passed straight back in as input."""
  mesh, _, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
  out_host_s = jax.sharding.NamedSharding(
      mesh, P("x", "y"), memory_kind="unpinned_host"
  )

  @functools.partial(jax.jit, out_shardings=out_host_s)
  def f(x):
    return x * 2

  first = f(inp)
  self.assertEqual(first.sharding, out_host_s)
  first_kinds = get_memory_kinds_from_executable(f, [inp])
  self._check_mem_kind(first_kinds[0], first.sharding, "unpinned_host")
  self.assertArraysEqual(first, np_inp * 2)

  # Input to `f` is on host memory.
  second = f(first)
  self.assertEqual(second.sharding, out_host_s)
  second_kinds = get_memory_kinds_from_executable(f, [first])
  self._check_mem_kind(second_kinds[0], second.sharding, "unpinned_host")
  self.assertArraysEqual(second, np_inp * 4)

  # The lowering for a host input must carry the buffer-placement marker.
  lowered_hlo = f.lower(first).as_text(dialect="hlo")
  self.assertIn('_xla_buffer_placement="arg"', lowered_hlo)
def test_jit_cache_hit_with_default_and_specified_mem_kind(self):
  """Default and explicit "tpu_hbm" `in_shardings` share one compile."""
  _, default_s, np_inp, _ = _create_inputs((8, 2), P("x", "y"))
  _, explicit_s, np_inp2, _ = _create_inputs(
      (8, 2), P("x", "y"), mem_kind="tpu_hbm"
  )

  def mul(x):
    return x @ x.T

  f = jax.jit(mul, in_shardings=default_s)
  g = jax.jit(mul, in_shardings=explicit_s)

  with jtu.count_jit_and_pmap_compiles() as count:
    out = f(np_inp)
    out2 = g(np_inp2)
  self.assertEqual(count[0], 1)

  self.assertArraysEqual(out, np_inp @ np_inp.T)
  self.assertArraysEqual(out2, np_inp2 @ np_inp2.T)
def test_sharding_devices_indices_map_cache_hit(self):
  """The indices-map cache is shared by default and explicit HBM kinds."""
  shape = (8, 2)
  mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
  default_s = NamedSharding(mesh, P("x", "y"))
  explicit_s = NamedSharding(mesh, P("x", "y"), memory_kind="tpu_hbm")

  default_s.devices_indices_map(shape)
  before = common_devices_indices_map.cache_info()
  explicit_s.devices_indices_map(shape)
  after = common_devices_indices_map.cache_info()

  # The second lookup must be a pure hit: one more hit, no new miss.
  self.assertEqual(after.hits, before.hits + 1)
  self.assertEqual(after.misses, before.misses)
def test_device_put_host_to_hbm(self):
  """`device_put` moves a host-memory array onto "tpu_hbm"."""
  mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
  s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host")
  inp = jnp.arange(16).reshape(8, 2)

  # Produce a host-resident array through an identity jit.
  @functools.partial(jax.jit, out_shardings=s_host)
  def f(x):
    return x

  on_host = f(inp)
  self.assertEqual(on_host.sharding, s_host)

  s_hbm = s_host.with_memory_kind("tpu_hbm")
  on_hbm = jax.device_put(on_host, s_hbm)
  self._check_device_put_addressable_shards(on_hbm, inp, s_hbm, "tpu_hbm")
def test_device_put_hbm_to_host(self):
  """`device_put` moves unsharded and sharded HBM arrays to host."""
  mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
  s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host")
  inp = jnp.arange(16).reshape(8, 2)

  # Source: the array as created by jnp.arange.
  on_host = jax.device_put(inp, s_host)
  self._check_device_put_addressable_shards(
      on_host, inp, s_host, "unpinned_host")

  # Source: the same data already sharded on "tpu_hbm".
  sharded_inp = jax.device_put(inp, s_host.with_memory_kind("tpu_hbm"))
  sharded_on_host = jax.device_put(sharded_inp, s_host)
  self._check_device_put_addressable_shards(
      sharded_on_host, sharded_inp, s_host, "unpinned_host")
def test_device_put_different_device_and_memory_host_to_hbm(self):
  """Host memory on device 0 can be moved to "tpu_hbm" on device 2."""
  if jax.device_count() < 3:
    raise unittest.SkipTest("Test requires >=3 devices")

  src_sharding = SingleDeviceSharding(
      jax.devices()[0], memory_kind="unpinned_host"
  )
  on_host_dev0 = jax.device_put(jnp.arange(8), src_sharding)

  dev2 = jax.devices()[2]
  on_hbm_dev2 = jax.device_put(
      on_host_dev0, SingleDeviceSharding(dev2, memory_kind="tpu_hbm"))

  # Both the array and its single shard must report the new placement.
  shard_sharding = on_hbm_dev2.addressable_shards[0].data.sharding
  self.assertEqual(on_hbm_dev2.sharding.memory_kind, "tpu_hbm")
  self.assertEqual(on_hbm_dev2.sharding._device, dev2)
  self.assertEqual(shard_sharding._device, dev2)
  self.assertEqual(shard_sharding.memory_kind, "tpu_hbm")
def test_device_put_different_device_and_memory_hbm_to_host(self):
  """A default-placed array can be moved to host memory on device 2."""
  if jax.device_count() < 3:
    raise unittest.SkipTest("Test requires >=3 devices")

  source = jnp.arange(8)
  dev2 = jax.devices()[2]
  on_host_dev2 = jax.device_put(
      source, SingleDeviceSharding(dev2, memory_kind="unpinned_host"))

  # Both the array and its single shard must report the new placement.
  shard_sharding = on_host_dev2.addressable_shards[0].data.sharding
  self.assertEqual(on_host_dev2.sharding.memory_kind, "unpinned_host")
  self.assertEqual(on_host_dev2.sharding._device, dev2)
  self.assertEqual(shard_sharding._device, dev2)
  self.assertEqual(shard_sharding.memory_kind, "unpinned_host")
def test_device_put_resharding(self):
  """`device_put` can reshard and switch memory kind in one step."""
  mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
  s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")
  s_hbm = s_host.with_memory_kind("tpu_hbm")
  np_inp = np.arange(16).reshape(8, 2)
  first_device = jax.devices()[0]

  # Reshard single device array on HBM to multi device on host
  single_on_hbm = jax.device_put(
      jnp.arange(16).reshape(8, 2),
      SingleDeviceSharding(first_device, memory_kind="tpu_hbm"))
  sharded_on_host = jax.device_put(single_on_hbm, s_host)
  self._check_device_put_addressable_shards(
      sharded_on_host, np_inp, s_host, "unpinned_host")

  # Reshard single device array on host to multi device on hbm
  single_on_host = jax.device_put(
      jnp.arange(16).reshape(8, 2),
      SingleDeviceSharding(first_device, memory_kind="unpinned_host"))
  sharded_on_hbm = jax.device_put(single_on_host, s_hbm)
  self._check_device_put_addressable_shards(
      sharded_on_hbm, np_inp, s_hbm, "tpu_hbm")
def test_jit_host_inputs_via_device_put_outside(self):
  """Each jit output keeps the memory kind of the input it came from."""
  mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
  s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")
  s_hbm = s_host.with_memory_kind("tpu_hbm")

  inp = jnp.arange(16).reshape(8, 2)
  expected = np.arange(16).reshape(8, 2) * 2
  inp_host = jax.device_put(inp, s_host)
  inp_hbm = jax.device_put(inp, s_hbm)

  @jax.jit
  def f(x, y):
    return x * 2, y * 2

  out_host, out_hbm = f(inp_host, inp_hbm)
  self._check_device_put_addressable_shards(
      out_host, expected, s_host, "unpinned_host")
  self._check_device_put_addressable_shards(
      out_hbm, expected, s_hbm, "tpu_hbm")
def test_device_put_numpy_array(self):
  """A NumPy array can be `device_put` onto HBM or host memory."""
  mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
  np_inp = np.arange(16).reshape(8, 2)
  s_hbm = NamedSharding(mesh, P(("x", "y")), memory_kind="tpu_hbm")
  s_host = s_hbm.with_memory_kind("unpinned_host")

  on_hbm = jax.device_put(np_inp, s_hbm)
  self._check_device_put_addressable_shards(on_hbm, np_inp, s_hbm, "tpu_hbm")

  on_host = jax.device_put(np_inp, s_host)
  self._check_device_put_addressable_shards(
      on_host, np_inp, s_host, "unpinned_host")
def test_device_put_numpy_scalar(self):
  """A NumPy scalar can be `device_put` onto HBM or host memory."""
  scalar = np.float32(8)
  s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="tpu_hbm")
  s_host = s_hbm.with_memory_kind("unpinned_host")

  on_hbm = jax.device_put(scalar, s_hbm)
  self._check_device_put_addressable_shards(on_hbm, scalar, s_hbm, "tpu_hbm")

  on_host = jax.device_put(scalar, s_host)
  self._check_device_put_addressable_shards(
      on_host, scalar, s_host, "unpinned_host")
def test_device_put_python_scalar(self):
  """A Python float can be `device_put` onto HBM or host memory."""
  value = float(8)
  s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="tpu_hbm")
  s_host = s_hbm.with_memory_kind("unpinned_host")

  # `index=False`: a Python scalar cannot be indexed by a shard index.
  on_hbm = jax.device_put(value, s_hbm)
  self._check_device_put_addressable_shards(
      on_hbm, value, s_hbm, "tpu_hbm", index=False)

  on_host = jax.device_put(value, s_host)
  self._check_device_put_addressable_shards(
      on_host, value, s_host, "unpinned_host", index=False)
def test_device_put_python_int(self):
  """A Python int can be `device_put` onto HBM or host memory."""
  value = 8
  s_hbm = SingleDeviceSharding(jax.devices()[0], memory_kind="tpu_hbm")
  s_host = s_hbm.with_memory_kind("unpinned_host")

  # `index=False`: a Python int cannot be indexed by a shard index.
  on_hbm = jax.device_put(value, s_hbm)
  self._check_device_put_addressable_shards(
      on_hbm, value, s_hbm, "tpu_hbm", index=False)

  on_host = jax.device_put(value, s_host)
  self._check_device_put_addressable_shards(
      on_host, value, s_host, "unpinned_host", index=False)
def test_trivial_computation(self):
  """An identity jit preserves the input's sharding and memory kind."""
  mesh = jtu.create_global_mesh((2, 1), ("x", "y"))
  np_inp = np.arange(16).reshape(8, 2)

  # Input with the default memory kind.
  s_hbm = NamedSharding(mesh, P("x"))
  hbm_inp = jax.device_put(np_inp, s_hbm)
  f = jax.jit(lambda x: x)
  hbm_out = f(hbm_inp)
  self.assertArraysEqual(hbm_out, np_inp)
  self.assertEqual(hbm_out.sharding, s_hbm)

  # Input on host memory, through a freshly created identity jit.
  s_host = NamedSharding(mesh, P(None, "x"), memory_kind="unpinned_host")
  host_inp = jax.device_put(np_inp, s_host)
  f = jax.jit(lambda x: x)
  host_out = f(host_inp)
  self.assertArraysEqual(host_out, np_inp)
  self.assertEqual(host_out.sharding, s_host)
def test_no_donation_across_memory_kinds(self):
  """Donating an HBM input to a host-memory output warns and adds no
  input/output alias."""
  mesh = jtu.create_global_mesh((2, 1), ("x", "y"))
  np_inp = np.arange(16).reshape(8, 2)
  s_hbm = NamedSharding(mesh, P("x"))
  s_host = s_hbm.with_memory_kind("unpinned_host")
  inp = jax.device_put(np_inp, s_hbm)

  @functools.partial(jax.jit, out_shardings=s_host, donate_argnums=0)
  def f(x):
    return x * 2

  with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    f(inp)

    # Exactly one warning, and it is the unusable-donation UserWarning.
    self.assertLen(caught, 1)
    last_warning = caught[-1]
    self.assertTrue(issubclass(last_warning.category, UserWarning))
    self.assertIn(
        "Some donated buffers were not usable:", str(last_warning.message))

  lowered_text = f.lower(inp).as_text("hlo")
  self.assertNotIn("input_output_alias", lowered_text)
@parameterized.named_parameters(
    ("hbm_to_host", "tpu_hbm", "unpinned_host"),
    ("host_to_hbm", "unpinned_host", "tpu_hbm")
)
def test_device_put_memory_kind_no_sharding(self, inp_mem_kind, out_mem_kind):
  """`TransferToMemoryKind` mid-computation sets the output memory kind."""
  mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
  np_inp = np.arange(16).reshape(8, 2)
  in_s = NamedSharding(mesh, P("x", "y"), memory_kind=inp_mem_kind)
  inp = jax.device_put(np_inp, in_s)

  @jax.jit
  def f(x):
    product = x @ x.T
    moved = jax.device_put(product, TransferToMemoryKind(out_mem_kind))
    return moved * 2

  out = f(inp)
  expected_value = (np_inp @ np_inp.T) * 2
  expected_s = NamedSharding(mesh, P("x"), memory_kind=out_mem_kind)
  self._check_device_put_addressable_shards(
      out, expected_value, expected_s, out_mem_kind)

  executable_kind = get_memory_kinds_from_executable(f, [inp])
  self._check_mem_kind(executable_kind[0], out.sharding, out_mem_kind)
@parameterized.named_parameters(
    ("hbm_to_host", "tpu_hbm", "unpinned_host"),
    ("host_to_hbm", "unpinned_host", "tpu_hbm")
)
def test_device_put_memory_kind_no_sharding_output(
    self, inp_mem_kind, out_mem_kind):
  """`TransferToMemoryKind` applied directly to the returned value."""
  mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
  np_inp = np.arange(16).reshape(8, 2)
  in_s = NamedSharding(mesh, P("x", "y"), memory_kind=inp_mem_kind)
  inp = jax.device_put(np_inp, in_s)

  @jax.jit
  def f(x):
    product = x @ x.T
    return jax.device_put(product, TransferToMemoryKind(out_mem_kind))

  out = f(inp)
  expected_s = NamedSharding(mesh, P("x"), memory_kind=out_mem_kind)
  self._check_device_put_addressable_shards(
      out, np_inp @ np_inp.T, expected_s, out_mem_kind)

  executable_kind = get_memory_kinds_from_executable(f, [inp])
  self._check_mem_kind(executable_kind[0], out.sharding, out_mem_kind)
@parameterized.named_parameters(
    ("hbm_to_host", "tpu_hbm", "unpinned_host"),
    ("host_to_hbm", "unpinned_host", "tpu_hbm")
)
def test_device_put_memory_kind_no_sharding_input(
    self, inp_mem_kind, out_mem_kind):
  """`TransferToMemoryKind` applied directly to the jit argument."""
  mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
  np_inp = np.arange(16).reshape(8, 2)
  s = NamedSharding(mesh, P("x", "y"), memory_kind=inp_mem_kind)
  inp = jax.device_put(np_inp, s)

  @jax.jit
  def f(x):
    y = jax.device_put(x, TransferToMemoryKind(out_mem_kind))
    return y

  # committed sharded input.
  sharded_out = f(inp)
  self.assertTrue(sharded_out._committed)
  self._check_device_put_addressable_shards(
      sharded_out, np_inp, s.with_memory_kind(out_mem_kind), out_mem_kind)

  # Committed single-device input (device 1).
  s1 = SingleDeviceSharding(jax.devices()[1], memory_kind=inp_mem_kind)
  committed_single_device_inp = jax.device_put(np_inp, s1)
  single_out = f(committed_single_device_inp)
  self.assertTrue(single_out._committed)
  self._check_device_put_addressable_shards(
      single_out, np_inp, s1.with_memory_kind(out_mem_kind), out_mem_kind)

  @jax.jit
  def g(x):
    y = jax.device_put(x, TransferToMemoryKind(out_mem_kind))
    return y

  # Uncommitted input but output will be committed because of device_put.
  uncommitted_out = g(np_inp)
  self.assertTrue(uncommitted_out._committed)
  self._check_device_put_addressable_shards(
      uncommitted_out, np_inp,
      SingleDeviceSharding(jax.devices()[0], memory_kind=out_mem_kind),
      out_mem_kind)

  # NumPy input with the sharding supplied through `in_shardings`.
  @functools.partial(jax.jit, in_shardings=s)
  def h(x):
    y = jax.device_put(x, TransferToMemoryKind(out_mem_kind))
    return y

  in_shardings_out = h(np_inp)
  self.assertTrue(in_shardings_out._committed)
  self._check_device_put_addressable_shards(
      in_shardings_out, np_inp, s.with_memory_kind(out_mem_kind),
      out_mem_kind)
def test_error_transfer_to_memory_kind_outside_jit(self):
  """`TransferToMemoryKind` used outside `jax.jit` raises ValueError."""
  expected_error = (
      "TransferToMemoryKind argument to jax.device_put can only be used"
      " inside jax.jit"
  )
  with self.assertRaisesRegex(ValueError, expected_error):
    jax.device_put(np.arange(16), TransferToMemoryKind("tpu_hbm"))
def test_single_mem_kind_donation(self):
  """With a single memory kind, donation adds an input/output alias."""
  mesh = jtu.create_global_mesh((2,), "x")

  @functools.partial(jax.jit, donate_argnums=0)
  def f(inp1):
    return inp1 * 2

  replicated = NamedSharding(mesh, P())
  x = jax.device_put(np.arange(16).reshape(8, 2), replicated)
  f(x)

  lowered_text = f.lower(x).as_text("hlo")
  self.assertIn("input_output_alias", lowered_text)
if __name__ == "__main__":
  # Script entry point: run through absltest with the JaxTestLoader from
  # jax's test utilities.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)