# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from functools import partial
from absl.testing import absltest
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding
from jax._src import config
from jax._src.layout import Layout, DeviceLocalLayout as DLL
from jax._src import test_util as jtu
from jax._src.util import safe_zip
from jax.experimental.compute_on import compute_on

config.parse_flags_with_absl()
jtu.request_cpu_devices(8)


class LayoutTest(jtu.JaxTestCase):

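  # A minimal sketch of the AOT auto-layout workflow the tests below exercise
  # (all names are as used in this file): `Layout(DLL.AUTO)` lets the compiler
  # pick the layouts, and the chosen layouts can be read back off the compiled
  # object, e.g.:
  #
  #   compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO),
  #                      out_shardings=Layout(DLL.AUTO)).lower(sds).compile()
  #   arg_layouts, kw_layouts = compiled.input_layouts
  #
  # The resulting concrete layouts can then be fed back to `jax.jit` as
  # in/out shardings, or to `jax.device_put`, to materialize arrays with
  # exactly the layouts the computation expects.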
  def test_auto_layout(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    shape1 = (128, 128)
    shape2 = (128, 128)
    s1 = NamedSharding(mesh, P('x', 'y'))
    s2 = NamedSharding(mesh, P('x'))

    def apply(x, y):
      return x.T, y.T

    def init(x, y):
      return x * 2, y * 2

    np_inp1 = np.arange(math.prod(shape1)).reshape(shape1)
    np_inp2 = np.arange(math.prod(shape2)).reshape(shape2)
    sds1 = jax.ShapeDtypeStruct(np_inp1.shape, np_inp1.dtype, sharding=s1)
    sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2)

    lowered_apply = jax.jit(apply, in_shardings=Layout(DLL.AUTO),
                            out_shardings=Layout(DLL.AUTO)).lower(sds1, sds2)
    compiled_apply = lowered_apply.compile()

    arg_layouts, kw_layouts = compiled_apply.input_layouts
    self.assertEmpty(kw_layouts)

    for i, o in zip(arg_layouts, compiled_apply.output_layouts):
      self.assertEqual(i.device_local_layout.major_to_minor,
                       o.device_local_layout.major_to_minor[::-1])

    init_compiled = jax.jit(
        init, out_shardings=arg_layouts).lower(sds1, sds2).compile()

    for i, o in zip(init_compiled.input_layouts[0],
                    init_compiled.output_layouts):
      self.assertEqual(i, o)

    arr1 = jax.device_put(np_inp1, s1)
    arr2 = jax.device_put(np_inp2, s2)

    with jtu.count_aot_jit_cpp_cache_miss() as init_count:
      init_out = init_compiled(arr1, arr2)
      init_compiled(arr1, arr2)
    self.assertEqual(init_count(), 1)

    self.assertEqual(init_out[0].layout, init_compiled.output_layouts[0])
    self.assertEqual(init_out[1].layout, init_compiled.output_layouts[1])

    with jtu.count_aot_jit_cpp_cache_miss() as apply_count:
      apply_out = compiled_apply(*init_out)
      compiled_apply(*init_out)
    self.assertEqual(apply_count(), 1)

    self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts[0])
    self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts[1])

    self.assertTupleEqual(apply_out[0].layout.device_local_layout.major_to_minor,
                          init_out[0].layout.device_local_layout.major_to_minor[::-1])
    self.assertTupleEqual(apply_out[1].layout.device_local_layout.major_to_minor,
                          init_out[1].layout.device_local_layout.major_to_minor[::-1])

    self.assertArraysEqual(init_out[0], np_inp1 * 2)
    self.assertArraysEqual(init_out[1], np_inp2 * 2)
    self.assertArraysEqual(apply_out[0], (np_inp1 * 2).T)
    self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T)

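  # `major_to_minor` lists dimensions from slowest- to fastest-varying, so a
  # rank-3 row-major array has major_to_minor == (0, 1, 2). The assertions
  # below compare against the reversed (minor-to-major) tuple, hence
  # (2, 1, 0) for the default layout.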
  def test_default_layout(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    shape = (4, 4, 2)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    s = NamedSharding(mesh, P('x', 'y'))
    sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
    arr = jax.device_put(np_inp, s)

    def f(x):
      return x.T

    lowered = jax.jit(f, in_shardings=None, out_shardings=None).lower(sds)
    compiled = lowered.compile()
    out = compiled(arr)

    self.assertTupleEqual(
        compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1],
        (2, 1, 0))
    self.assertTupleEqual(
        compiled.output_layouts.device_local_layout.major_to_minor[::-1],
        (2, 1, 0))
    self.assertArraysEqual(out, np_inp.T)
    self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))

    compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO),
                            out_shardings=Layout(DLL.AUTO)).lower(sds).compile()
    self.assertTupleEqual(
        compiled_auto.input_layouts[0][0].device_local_layout.major_to_minor[::-1],
        (2, 1, 0))
    self.assertTupleEqual(
        compiled_auto.output_layouts.device_local_layout.major_to_minor[::-1],
        (0, 1, 2))

    with self.assertRaisesRegex(
        ValueError, '`jax.jit` does not accept device-local layouts directly'):
      jax.jit(f, in_shardings=DLL.AUTO,
              out_shardings=DLL.AUTO).lower(sds).compile()

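  # An unparameterized `Layout()` leaves both the device-local layout and the
  # sharding unspecified: the input then keeps the default (row-major) layout
  # while the AUTO output layout is still chosen by the compiler.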
  def test_in_layouts_out_layouts(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    shape = (8, 8)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    s = NamedSharding(mesh, P('x', 'y'))
    arr = jax.device_put(np_inp, s)

    def f(x):
      return x.T

    compiled = jax.jit(f, in_shardings=Layout(),
                       out_shardings=Layout(DLL.AUTO)).lower(arr).compile()
    self.assertTupleEqual(
        compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1],
        (1, 0))
    self.assertTupleEqual(
        compiled.output_layouts.device_local_layout.major_to_minor[::-1],
        (0, 1))

    out = compiled(arr)
    self.assertArraysEqual(out, np_inp.T)
    self.assertEqual(out.layout, compiled.output_layouts)
    self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))

  def test_sharding_and_layouts(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    shape = (4, 8)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    s = NamedSharding(mesh, P('x', 'y'))

    compiled = jax.jit(lambda x: x.T, in_shardings=Layout(DLL.AUTO, s),
                       out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile()
    out = compiled(np_inp)
    self.assertTupleEqual(
        compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1],
        (1, 0))
    if not jtu.test_device_matches(['cpu']):
      self.assertTupleEqual(
          compiled.output_layouts.device_local_layout.major_to_minor[::-1],
          (0, 1))
    self.assertArraysEqual(out, np_inp.T)
    self.assertEqual(out.sharding, s)

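  # Only `z` and `b` are actually used by `f` below; the other arguments are
  # dead code. The layouts reported for the DCE'd args should still round-trip:
  # they can be fed back through `in_shardings` and used as `jax.device_put`
  # targets without error.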
  def test_dce_in_layouts(self):
    def f(x, y, z, a, b, c):
      return z * 2, b.T

    shape = (8, 2)
    inps = [np.arange(math.prod(shape)).reshape(shape)] * 6
    compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO),
                       out_shardings=Layout(DLL.AUTO)).lower(*inps).compile()
    arg_layouts, _ = compiled.input_layouts
    out1, out2 = compiled(*inps)

    compiled2 = jax.jit(f, in_shardings=arg_layouts).lower(*inps).compile()
    out3, out4 = compiled2(*inps)

    for l1, l2 in safe_zip(arg_layouts, compiled2.input_layouts[0]):
      self.assertEqual(l1, l2)

    self.assertArraysEqual(out1, out3)
    self.assertArraysEqual(out2, out4)

    arrs = [jax.device_put(i, l) for i, l in zip(inps, arg_layouts)]
    out5, out6 = jax.jit(f)(*arrs)
    self.assertArraysEqual(out1, out5)
    self.assertArraysEqual(out2, out6)

  def test_no_error_dced_args(self):
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
    shape = (8, 2)
    s = NamedSharding(mesh, P('x', 'y'))
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    arr1 = jax.device_put(np_inp, s)
    arr2 = jax.device_put(np_inp, s)
    arrs = [arr1, arr2]

    def f(x, y):
      return x * 2

    jf = jax.jit(f, in_shardings=Layout(DLL.AUTO, s),
                 out_shardings=Layout(DLL.AUTO, s))
    compiled = jf.lower(np_inp, np_inp).compile()
    arg_layouts, _ = compiled.input_layouts
    arrs = [jax.device_put(i, l) for i, l in zip(arrs, arg_layouts)]
    compiled(*arrs)

  def test_aot_layout_mismatch(self):
    if jtu.test_device_matches(['cpu', 'gpu']):
      # The test fails on CPU and GPU because compilation with both input
      # and output set to auto layout is underspecified: those compilers
      # choose the default layout as the input layout, and that choice
      # does not raise an exception.
      self.skipTest('This test does not work on CPU or GPU backends.')
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    shape = (256, 4, 2)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    s = NamedSharding(mesh, P('x'))

    sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
    arr = jax.device_put(np_inp, s)

    def f(x):
      return (x * 2).T

    with self.assertRaisesRegex(
        ValueError,
        'Layout passed to jit does not match the layout on the respective arg'):
      jax.jit(f, in_shardings=Layout(DLL.AUTO)).lower(arr)

    compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO),
                       out_shardings=Layout(DLL.AUTO)).lower(sds).compile()

    with self.assertRaisesRegex(
        ValueError,
        r'Compiled object called with input layout\(s\) does'
        r' not match the layout\(s\) the computation was'
        ' compiled with'):
      compiled(arr)

  @jtu.ignore_warning(category=DeprecationWarning,
                      message="backend and device argument")
  def test_cpu_default_backend_layout(self):
    inp = jax.device_put(np.ones((8, 8)), device=jax.devices('cpu')[0])
    out_cpu = jax.jit(jnp.dot)(inp, inp)

    jax.jit(jnp.dot, backend=jax.default_backend()).lower(
        out_cpu, out_cpu).compile()  # doesn't crash

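  # A concrete `Layout` (device-local layout plus sharding) is itself a valid
  # `jax.device_put` target, e.g., using the names from the test below:
  #
  #   col = compiled.output_layouts
  #   out = jax.device_put(np_inp, col)   # out.layout == col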
  def test_device_put_concrete_layout(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    shape = (8, 128)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    s = NamedSharding(mesh, P('x', 'y'))
    arr = jax.device_put(np_inp, s)

    compiled = jax.jit(
        lambda x: x * 2, out_shardings=Layout(DLL.AUTO)).lower(arr).compile()
    col = compiled.output_layouts

    out = jax.device_put(np_inp, col)
    self.assertEqual(out.layout, col)
    self.assertArraysEqual(out, np_inp)
    for s in out.addressable_shards:
      self.assertEqual(out.layout.device_local_layout,
                       s.data.layout.device_local_layout)

  def test_device_put_non_concrete_layout_error(self):
    np_inp = np.arange(16).reshape(8, 2)

    l1 = Layout(DLL.AUTO, SingleDeviceSharding(jax.devices()[0]))
    with self.assertRaisesRegex(
        ValueError, 'sharding and device_local_layout.*should be concrete'):
      jax.device_put(np_inp, l1)

    l2 = Layout(DLL.AUTO)
    with self.assertRaisesRegex(
        ValueError, 'sharding and device_local_layout.*should be concrete'):
      jax.device_put(np_inp, l2)

    l3 = Layout(None, SingleDeviceSharding(jax.devices()[0]))
    out = jax.device_put(np_inp, l3)
    self.assertArraysEqual(out, np_inp)
    self.assertTrue(out._committed)

  def invalid_layout_spec(self):
    x = np.arange(8)
    compiled = jax.jit(lambda x: x).lower(x).compile()
    with self.assertRaisesRegex(
        ValueError, 'Sharding has to be concrete when layout.*'):
      Layout(compiled.output_layouts[0], None)

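  # `jax.ShapeDtypeStruct` accepts a concrete `Layout` as its `sharding`
  # argument, which lets AOT lowering see the intended input layout without
  # materializing an array first.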
  def test_layout_on_sds(self):
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    arr = jax.device_put(np_inp, s)

    out_layout = jax.jit(jnp.sin, out_shardings=Layout(DLL.AUTO)).lower(
        arr).compile().output_layouts

    sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=out_layout)
    arg_layout, _ = jax.jit(lambda x: x * 2).lower(sds).compile().input_layouts
    self.assertEqual(arg_layout[0], out_layout)

    with self.assertRaisesRegex(
        TypeError,
        '`DeviceLocalLayout.AUTO` cannot be used in place of a device-local'
        ' layout in a `ShapeDtypeStruct`'):
      jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=Layout(DLL.AUTO))

  def test_make_array_from_callback(self):
    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    np_inp = np.arange(16).reshape(8, 2)
    sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)

    layout = jax.jit(lambda x: x * 2).lower(sds).compile().output_layouts

    out = jax.make_array_from_callback(np_inp.shape, layout,
                                       lambda idx: np_inp[idx])
    self.assertArraysEqual(out, np_inp)
    self.assertEqual(out.layout, layout)

    with self.assertRaisesRegex(
        TypeError,
        '`DeviceLocalLayout.AUTO` cannot be used in place of a device-local'
        ' layout'):
      jax.make_array_from_callback(np_inp.shape, Layout(DLL.AUTO, s),
                                   lambda idx: np_inp[idx])

    with self.assertRaisesRegex(
        TypeError, 'sharding should be an instance of `jax.sharding`'):
      jax.make_array_from_callback(
          np_inp.shape, Layout(None, None), lambda idx: np_inp[idx])

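  # `jax.lax.with_sharding_constraint` also accepts a `Layout`, constraining
  # both the sharding and the device-local layout of an intermediate value
  # inside a jitted computation, as the next two tests exercise.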
  def test_wsc_concrete_layout(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    shape = (16, 128)
    s = NamedSharding(mesh, P('x'))
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    arr = jax.device_put(np_inp, s)

    # Create a custom layout instead of using `arr.layout` to test the API.
    custom_dll = DLL(major_to_minor=(0, 1))

    @jax.jit
    def f(x):
      y = x.T
      # Constrain `y` to the original layout of `arr` because without it,
      # the layout of `y` would be the transpose of `arr`.
      return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s))

    out = f(arr)
    self.assertEqual(out.layout.device_local_layout.major_to_minor,
                     custom_dll.major_to_minor)
    self.assertEqual(out.layout, arr.layout)
    self.assertArraysEqual(out, np_inp.T)

  def test_wsc_bfloat16_concrete_layout(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    shape = (16, 128)
    s = NamedSharding(mesh, P('x'))
    inp = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape)
    arr = jax.device_put(inp, s)

    # Create a custom layout instead of using `arr.layout` to test the API.
    custom_dll = DLL(major_to_minor=(0, 1))

    @jax.jit
    def f(x):
      y = x.T
      # Constrain `y` to the original layout of `arr` because without it,
      # the layout of `y` would be the transpose of `arr`.
      return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s))

    out = f(arr)
    self.assertEqual(out.layout.device_local_layout.major_to_minor,
                     custom_dll.major_to_minor)
    self.assertEqual(out.layout, arr.layout)
    self.assertArraysEqual(out, inp.T)

  def test_device_put_user_concrete_layout(self):
    shape = (8, 128)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    dll = DLL(major_to_minor=(1, 0))
    s = SingleDeviceSharding(jax.devices()[0])

    out = jax.device_put(np_inp, Layout(dll, s))
    self.assertEqual(out.layout.device_local_layout.major_to_minor,
                     dll.major_to_minor)
    self.assertArraysEqual(out, np_inp)

  def test_device_put_user_concrete_layout_multi_device(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    shape = (16, 128)
    s = NamedSharding(mesh, P('x'))
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    jnp_inp = jnp.arange(math.prod(shape)).reshape(shape)
    arr = jax.device_put(np_inp, s)

    custom_layout = Layout(DLL(major_to_minor=(0, 1)), s)
    out1 = jax.device_put(arr, custom_layout)

    with jax.sharding.use_mesh(mesh):
      out2 = jax.device_put(arr, custom_layout)
      out3 = jax.device_put(jnp_inp, custom_layout)
      out4 = jax.device_put(np_inp, custom_layout)

    for o in [out1, out2, out3, out4]:
      self.assertArraysEqual(o, np_inp)
      self.assertEqual(o.layout.device_local_layout.major_to_minor,
                       custom_layout.device_local_layout.major_to_minor)

  def test_concrete_layout_jit(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    shape = (16, 128)
    s = NamedSharding(mesh, P('x'))
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    arr = jax.device_put(np_inp, s)

    def f(x):
      return x.T

    custom_dll = DLL(major_to_minor=(0, 1))
    f = jax.jit(f, out_shardings=Layout(custom_dll, s))

    out = f(arr)
    self.assertArraysEqual(out, np_inp.T)
    self.assertEqual(out.layout.device_local_layout.major_to_minor,
                     custom_dll.major_to_minor)

  def test_incompatible_aval_error(self):
    custom_dll = DLL(major_to_minor=(0, 1, 2))
    l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0]))
    inp = np.arange(8)

    @partial(jax.jit, in_shardings=l)
    def f(x):
      return x * 2

    with self.assertRaisesRegex(
        ValueError,
        '.*Length of major_to_minor and the rank of the value should match.*'):
      f(inp)

  def test_incompatible_aval_error_device_put(self):
    custom_dll = DLL(major_to_minor=(0, 1, 2))
    l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0]))
    inp = np.arange(8)

    with self.assertRaisesRegex(
        ValueError,
        '.*Length of major_to_minor and the rank of the value should match.*'):
      jax.device_put(inp, l)

  def test_concrete_layout_in_shardings(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    shape = (16, 128)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    arr = jax.device_put(np_inp, s)

    custom_dll = DLL(major_to_minor=(0, 1))

    @partial(jax.jit,
             in_shardings=Layout(custom_dll, s),
             out_shardings=Layout(DLL.AUTO))
    def f(x):
      return x.T

    out = f(arr)
    self.assertArraysEqual(out, np_inp.T)
    self.assertEqual(out.layout.device_local_layout.major_to_minor,
                     custom_dll.major_to_minor[::-1])

    custom_dll2 = DLL(major_to_minor=(1, 0))

    @partial(jax.jit, in_shardings=Layout(custom_dll2, s))
    def g(x):
      return x.T

    with self.assertRaisesRegex(
        ValueError,
        'Layout passed to jit does not match the layout on the respective arg'):
      g(arr)

  def test_in_layouts_jit_jnp_input(self):
    major_last_layout = DLL(major_to_minor=(1, 0))
    sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])

    f = jax.jit(lambda x: x + 1,
                in_shardings=Layout(major_last_layout, sharding))

    arr = jnp.arange(8 * 128).reshape(8, 128)
    out = f(arr)
    self.assertArraysEqual(out, arr + 1)

    # cpp dispatch should call into shard_args from cpp.
    out2 = f(arr)
    self.assertArraysEqual(out2, arr + 1)

    np_inp = np.arange(8 * 128).reshape(8, 128)
    out3 = f(np_inp)
    self.assertArraysEqual(out3, np_inp + 1)

    # cpp dispatch should call into shard_args from cpp.
    out4 = f(np_inp)
    self.assertArraysEqual(out4, np_inp + 1)

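  # Donation interacts with layout: a donated input's buffer is only reused
  # when its layout is compatible with the output's. The donation tests below
  # cover matching layouts, AUTO layouts, and (further down) a mismatching
  # in/out layout where the donated buffer is left un-deleted.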
  def test_layout_donation(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    shape = (16, 128)
    np_inp = np.arange(math.prod(shape)).reshape(shape)

    custom_dll = DLL(major_to_minor=(0, 1))
    arr = jax.device_put(np_inp, Layout(custom_dll, s))

    @partial(jax.jit, in_shardings=Layout(custom_dll, s), donate_argnums=0)
    def f(x):
      return x

    f(arr)
    self.assertTrue(arr.is_deleted())

  def test_layout_donation_auto(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    shape = (128, 16)
    np_inp = np.arange(math.prod(shape)).reshape(shape)

    arr = jax.device_put(np_inp, s)

    @partial(jax.jit, out_shardings=Layout(DLL.AUTO), donate_argnums=0)
    def f(x):
      return x * x

    f(arr)
    self.assertTrue(arr.is_deleted())

  def test_layout_donation_matching_in_and_out(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    shape = (128, 16)
    np_inp = np.arange(math.prod(shape)).reshape(shape)

    custom_dll = DLL(major_to_minor=(0, 1))
    l = Layout(custom_dll, s)
    arr = jax.device_put(np_inp, l)

    @partial(jax.jit, in_shardings=l, out_shardings=l, donate_argnums=0)
    def f(x):
      return x * x

    f(arr)
    self.assertTrue(arr.is_deleted())

  @jtu.skip_on_devices('cpu', 'gpu')
  def test_layout_donation_mismatching_in_and_out_fails(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    shape = (16 * 2, 32016 * 2)
    np_inp = np.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape)

    custom_dll1 = DLL(major_to_minor=(1, 0), _tiling=((8, 128), (2, 1)))
    l1 = Layout(custom_dll1, s)
    arr = jax.device_put(np_inp, s)

    @partial(jax.jit, out_shardings=l1, donate_argnums=0)
    def f(x):
      return x * x

    sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
    f.lower(sds).compile()(arr)
    self.assertFalse(arr.is_deleted())

  def test_donation_error_on_auto(self):
    @partial(jax.jit, donate_argnums=0, in_shardings=Layout(DLL.AUTO))
    def f(x):
      return x * 2

    with self.assertRaisesRegex(
        ValueError, ".*Did you mean to set the.*output layout.*AUTO.*"):
      f(jnp.arange(8))

    @partial(jax.jit, donate_argnums=0, out_shardings=Layout(DLL.AUTO))
    def g(x):
      return x * 2

    with self.assertRaisesRegex(
        ValueError, ".*Did you mean to set the.*input layout.*AUTO.*"):
      g(jnp.arange(8))

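  # The sparsecore/host tests below rely on `_tiling`, a private knob of
  # `DeviceLocalLayout`: e.g. `DLL(major_to_minor=(0, 1), _tiling=((8,),))`
  # requests the tiled layout used here for sparsecore arrays, while
  # `compute_on('tpu_sparsecore')` / `compute_on('device_host')` routes the
  # wrapped jitted function to that compute resource.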
  def test_sparsecore_compute(self):
    if not (jax.devices()[0].device_kind == 'TPU v5' or
            jtu.is_device_tpu_at_least(6)):
      self.skipTest('Does not have a sparsecore present')
    shape = (128, 128)
    inp = jnp.arange(math.prod(shape)).reshape(shape)

    dll = DLL(major_to_minor=(0, 1), _tiling=((8,),))
    s = SingleDeviceSharding(jax.devices()[0])
    sparse_layout = Layout(dll, s)
    sparsecore_arr = jax.device_put(inp, sparse_layout)
    dense_layout = Layout(DLL(major_to_minor=(0, 1)), s)

    @compute_on('tpu_sparsecore')
    @jax.jit
    def sparsecore_compute(x):
      return x * x

    @partial(jax.jit, out_shardings=(dense_layout, sparse_layout))
    def f(x, y):
      return x * 2, sparsecore_compute(y)

    f(inp, sparsecore_arr)

  def test_sparsecore_compute_twice(self):
    if not (
        jax.devices()[0].device_kind == 'TPU v5'
        or jtu.is_device_tpu_at_least(6)
    ):
      self.skipTest('Does not have a sparsecore present')
    shape = (4096, 8)
    inp = jnp.arange(math.prod(shape)).reshape(shape)

    dll = DLL(major_to_minor=(0, 1), _tiling=((8,),))
    s = SingleDeviceSharding(jax.devices()[0])
    sparse_layout = Layout(dll, s)
    sparsecore_arr = jax.device_put(inp, sparse_layout)

    @compute_on('tpu_sparsecore')
    @jax.jit
    def sparsecore_multiply(x, y):
      return x * y

    @compute_on('tpu_sparsecore')
    @jax.jit
    def sparsecore_add(x, y):
      return x + y

    @partial(jax.jit, donate_argnums=0, out_shardings=sparse_layout)
    def f(x):
      return sparsecore_multiply(sparsecore_add(x, x) + 1, x)

    f(sparsecore_arr)

  def test_sparsecore_and_host_compute(self):
    if not (
        jax.devices()[0].device_kind == 'TPU v5'
        or jtu.is_device_tpu_at_least(6)
    ):
      self.skipTest('Does not have a sparsecore present')
    shape = (128, 128)
    inp = jnp.arange(math.prod(shape)).reshape(shape)
    s = SingleDeviceSharding(jax.devices()[0])

    sparse_dll = DLL(major_to_minor=(0, 1), _tiling=((8,),))
    sparse_layout = Layout(sparse_dll, s)
    sparsecore_arr = jax.device_put(inp, sparse_layout)

    host_dll = DLL(major_to_minor=(0, 1), _tiling=((1,),))
    host_layout = Layout(host_dll, s)
    host_arr = jax.device_put(inp, host_layout)

    @compute_on('tpu_sparsecore')
    @jax.jit
    def sparsecore_compute(x):
      return x * x

    @compute_on('device_host')
    @jax.jit
    def host_compute(x):
      return x + x

    @partial(
        jax.jit,
        in_shardings=(sparse_layout, host_layout),
        out_shardings=(sparse_layout, host_layout),
    )
    def f(x, y):
      return sparsecore_compute(x), host_compute(y)

    f(sparsecore_arr, host_arr)

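  # The C++ fast-path dispatch cache keys on layout as well: calling the same
  # jitted function with two arrays that differ only in device-local layout
  # should produce two cache misses, as asserted below.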
  def test_cpp_layout_cache_miss(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    shape = (16, 16)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    arr = jax.device_put(np_inp, s)

    arr_m2m = arr.layout.device_local_layout.major_to_minor
    custom_layout = Layout(DLL(major_to_minor=arr_m2m[::-1]), s)
    arr2 = jax.device_put(np_inp, custom_layout)

    @jax.jit
    def f(x):
      return x @ x.T

    with jtu.count_pjit_cpp_cache_miss() as count:
      out = f(arr)
      out2 = f(arr2)
    self.assertEqual(count(), 2)

    self.assertArraysEqual(out, np_inp @ np_inp.T)
    self.assertArraysEqual(out2, np_inp @ np_inp.T)

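  # When a donated input's layout already matches the requested output layout,
  # the lowering uses XLA input/output aliasing ('tf.aliasing_output') rather
  # than the 'jax.buffer_donor' attribute, as asserted on the lowered text.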
  def test_layout_donation_with_default_layout(self):
    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    shape = (16, 16)
    np_inp = np.arange(math.prod(shape)).reshape(shape)
    arr = jax.device_put(np_inp, s)
    out_layout = Layout(arr.layout.device_local_layout, s)

    @partial(jax.jit, out_shardings=out_layout, donate_argnums=0)
    def f(x):
      return x * 2

    lowered_text = f.lower(arr).as_text()
    self.assertIn('tf.aliasing_output = 0', lowered_text)
    self.assertNotIn('jax.buffer_donor', lowered_text)

    out = f(arr)
    self.assertArraysEqual(out, np_inp * 2)
    self.assertEqual(out.layout, out_layout)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())