# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import itertools as it
import os
import re
from itertools import product, permutations
from typing import (Tuple, List, Dict, Generator, Iterator, Union, Optional)
from unittest import SkipTest

import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
from functools import partial

import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax._src import test_util as jtu
from jax import vmap
from jax import lax
from jax import core
from jax.core import NamedShape
from jax.experimental import maps
from jax.experimental import global_device_array
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.experimental.pjit import PartitionSpec as P
from jax.experimental.maps import xmap, serial_loop, SerialLoop
from jax.errors import JAXTypeError
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.util import curry, unzip2, prod, safe_zip
from jax._src.lax.parallel import pgather
from jax.interpreters import batching, pxla
from jax.ad_checkpoint import checkpoint

from jax.config import config
config.parse_flags_with_absl()

# TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
# Run all tests with 8 CPU devices.
def setUpModule():
  global prev_xla_flags
  prev_xla_flags = os.getenv("XLA_FLAGS")
  flags_str = prev_xla_flags or ""
  # Don't override user-specified device count, or other XLA flags.
  if "xla_force_host_platform_device_count" not in flags_str:
    os.environ["XLA_FLAGS"] = (flags_str +
                               " --xla_force_host_platform_device_count=8")
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()

# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
  if prev_xla_flags is None:
    del os.environ["XLA_FLAGS"]
  else:
    os.environ["XLA_FLAGS"] = prev_xla_flags
  xla_bridge.get_backend.cache_clear()

# -------------------- Itertools helpers --------------------

def partitions(s, k):
  for indices in product(range(k), repeat=len(s)):
    outs = [[] for _ in range(k)]
    for i, elt in zip(indices, s):
      outs[i].append(elt)
    yield outs

def powerset(s):
  s = list(s)
  return it.chain.from_iterable(it.combinations(s, r) for r in range(len(s)+1))
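# Added commentary: small worked examples of the two helpers above, computed
# by hand from their definitions (not part of the upstream test file):
#
#   list(partitions('ab', 2)) ==
#     [[['a', 'b'], []], [['a'], ['b']], [['b'], ['a']], [[], ['a', 'b']]]
#   list(powerset('ab')) == [(), ('a',), ('b',), ('a', 'b')]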
# -------------------- vmap test helpers --------------------

ensure_bdim_p = core.Primitive('ensure_bdim')
ensure_bdim_p.def_abstract_eval(lambda x, **kwargs: core.raise_to_shaped(x))
def _ensure_bdim_batcher(axis_size, frame_name, main_type, vals_in, dims_in, axis_name, bdim):
  v, = vals_in
  d, = dims_in
  assert d is not batching.not_mapped
  return jnp.moveaxis(v, d, bdim), bdim
batching.axis_primitive_batchers[ensure_bdim_p] = _ensure_bdim_batcher
batching.primitive_batchers[ensure_bdim_p] = lambda v, d: (v[0], d[0])
core.axis_substitution_rules[ensure_bdim_p] = partial(jax._src.lax.parallel._subst_all_names_in_param,
                                                      'axis_name')

def ensure_bdim(x, axis_name, bdim):
  return ensure_bdim_p.bind(x, axis_name=(axis_name,), bdim=bdim)
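# Added commentary (not part of the upstream file): ensure_bdim is a test-only
# primitive. Its batching rule asserts that the traced value really is batched
# along the given axis (`d is not batching.not_mapped`) and moves that batch
# dimension to position `bdim`; testNestedMap below uses it to pin down where
# the 'v' axis ends up in the result.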
# -------------------- Axis resources generation --------------------

AxisResources = Dict[str, Union[str, Tuple[str, ...]]]

def schedules(sizes: Dict[str, int]
              ) -> Generator[Tuple[AxisResources, jtu.MeshSpec], None, None]:
  """Test utility generating xmap parallel schedules from logical names & sizes.

  Args:
    sizes: dict mapping logical axis name to its corresponding size.

  Returns:
    A generator producing finitely many values, where each value is a pair in
    which the first element is a value suitable for xmap's axis_resources
    argument and the second element is a list of pairs with the first element
    representing a generated physical mesh axis name and the second element
    representing a corresponding generated mesh axis size. The generated mesh
    names/sizes can be used to define a physical mesh in tests.

  This function doesn't generate schedules which map distinct logical axis names
  to the same parallel resource name. It only generates parallel resources; the
  rest are implicitly left for vectorization. Parallel resource names are
  generated by prepending an 'r', 'r1', or 'r2' to the corresponding logical
  name.

  Examples:
    >>> for sched in schedules({'i': 2, 'j': 4}):
    ...   print(sched)
    ({}, [])
    ({'i': 'ri'}, [('ri', 1)])
    ({'i': 'ri'}, [('ri', 2)])
    ({'i': ('r1i', 'r2i')}, [('r1i', 1), ('r2i', 1)])
    ({'i': ('r1i', 'r2i')}, [('r1i', 1), ('r2i', 2)])
    ({'i': ('r1i', 'r2i')}, [('r1i', 2), ('r2i', 1)])
    ({'j': 'rj'}, [('rj', 1)])
    ({'j': 'rj'}, [('rj', 2)])
    ({'j': 'rj'}, [('rj', 4)])
    ({'j': ('r1j', 'r2j')}, [('r1j', 1), ('r2j', 1)])
    ({'j': ('r1j', 'r2j')}, [('r1j', 1), ('r2j', 2)])
    ({'j': ('r1j', 'r2j')}, [('r1j', 1), ('r2j', 4)])
    ({'j': ('r1j', 'r2j')}, [('r1j', 2), ('r2j', 1)])
    ({'j': ('r1j', 'r2j')}, [('r1j', 2), ('r2j', 2)])
    ({'j': ('r1j', 'r2j')}, [('r1j', 4), ('r2j', 1)])
    ({'i': 'ri', 'j': 'rj'}, [('ri', 1), ('rj', 1)])
    ({'i': 'ri', 'j': 'rj'}, [('ri', 1), ('rj', 2)])
    ({'i': 'ri', 'j': 'rj'}, [('ri', 1), ('rj', 4)])
    ({'i': 'ri', 'j': 'rj'}, [('ri', 2), ('rj', 1)])
    ({'i': 'ri', 'j': 'rj'}, [('ri', 2), ('rj', 2)])
    ({'i': 'ri', 'j': 'rj'}, [('ri', 2), ('rj', 4)])
    ({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 1), ('r2j', 1)])
    ({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 1), ('r2j', 2)])
    ({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 1), ('r2j', 4)])
    ({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 2), ('r2j', 1)])
    ({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 2), ('r2j', 2)])
    ({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 4), ('r2j', 1)])
    ({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 1), ('r2j', 1)])
    ({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 1), ('r2j', 2)])
    ({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 1), ('r2j', 4)])
    ({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 2), ('r2j', 1)])
    ({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 2), ('r2j', 2)])
    ({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 4), ('r2j', 1)])
    ({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 1), ('r1i', 1), ('r2i', 1)])
    ({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 1), ('r1i', 1), ('r2i', 2)])
    ({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 1), ('r1i', 2), ('r2i', 1)])
    ({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 2), ('r1i', 1), ('r2i', 1)])
    ({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 2), ('r1i', 1), ('r2i', 2)])
    ({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 2), ('r1i', 2), ('r2i', 1)])
    ({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 4), ('r1i', 1), ('r2i', 1)])
    ({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 4), ('r1i', 1), ('r2i', 2)])
    ({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 4), ('r1i', 2), ('r2i', 1)])
  """
  def divisors(n: int) -> List[int]:
    return [m for m in range(1, n + 1) if not n % m]

  def divisors2(n: int) -> Iterator[Tuple[int, int]]:
    for k1 in divisors(n):
      for k2 in divisors(n // k1):
        yield (k1, k2)
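  # Added commentary (hand-computed from the definitions above, not upstream):
  #   divisors(4)        == [1, 2, 4]
  #   list(divisors2(4)) == [(1, 1), (1, 2), (1, 4), (2, 1), (2, 2), (4, 1)]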
  # choose a subset of logical axis names to map to parallel resources
  for names in powerset(sizes):
    # partition that set of logical axis names into two subsets: one subset to
    # map to one parallel resource axis and a second subset to map to two
    # parallel resource axes.
    for names1, names2 in partitions(names, 2):
      # to avoid generating too many complex cases, we skip generating cases
      # where more than one logical axis name is to be mapped to two parallel
      # resource axes. comment out this line to generate more complex tests.
      if len(names2) > 1: continue
      # make up parallel resource axis names for each logical axis
      axis_resources1 = ((name, 'r' + name) for name in names1)
      axis_resources2 = ((name, ('r1' + name, 'r2' + name)) for name in names2)
      axis_resources = dict(it.chain(axis_resources1, axis_resources2))
      # make up sizes for each resource axis, where the size must divide the
      # corresponding logical axis
      for mesh_sizes1 in product(*(divisors(sizes[n]) for n in names1)):
        for mesh_sizes2 in product(*(divisors2(sizes[n]) for n in names2)):
          mesh_data1 = (('r' + name, size) for name, size in zip(names1, mesh_sizes1))
          mesh_data2 = (pair for name, (size1, size2) in zip(names2, mesh_sizes2)
                        for pair in [('r1' + name, size1), ('r2' + name, size2)])
          mesh_data = list(it.chain(mesh_data1, mesh_data2))
          yield axis_resources, mesh_data

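# Added commentary: a minimal sketch of how these schedules are typically
# consumed (hypothetical example; the actual uses are in the sampler and pdot
# tests further below, which feed mesh_data to jtu.with_mesh):
#
#   for axis_resources, mesh_data in schedules({'i': 4}):
#     with jtu.with_mesh(mesh_data):
#       xmap(lambda x: x, in_axes=['i'], out_axes=['i'],
#            axis_resources=axis_resources)(jnp.arange(4))
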
class XMapTestCase(jtu.BufferDonationTestCase):
  pass


# A mixin that enables SPMD lowering tests
class SPMDTestMixin:
  def setUp(self):
    if jtu.device_under_test() not in ['tpu', 'gpu']:
      raise SkipTest
    super().setUp()
    jtu.set_spmd_lowering_flag(True)

  def tearDown(self):
    jtu.restore_spmd_lowering_flag()


class ManualSPMDTestMixin:
  def setUp(self):
    if jtu.device_under_test() not in ['tpu', 'gpu']:
      raise SkipTest
    if not hasattr(xla_client.OpSharding.Type, "MANUAL"):
      raise SkipTest
    super().setUp()
    jtu.set_spmd_lowering_flag(True)
    jtu.set_spmd_manual_lowering_flag(True)

  def tearDown(self):
    jtu.restore_spmd_manual_lowering_flag()
    jtu.restore_spmd_lowering_flag()

class XMapTest(XMapTestCase):

  def testBasic(self):
    local_devices = list(jax.local_devices())
    if len(local_devices) < 4:
      raise SkipTest("Test requires at least 4 local devices")
    def f(a, b):
      return a * 2, b * 4
    devices = np.array(local_devices[:4]).reshape((2, 2))
    with maps.Mesh(devices, ('x', 'y')):
      fm = xmap(f,
                in_axes=({0: 'a', 1: 'b'}, ['c', ...]),
                out_axes=({0: 'a', 1: 'b'}, ['c', ...]),
                axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
      ashape = (16, 8, 5)
      a = jnp.arange(np.prod(ashape)).reshape(ashape)
      bshape = (2, 7)
      b = jnp.arange(np.prod(bshape)).reshape(bshape)
      c, d = fm(a, b)
      self.assertAllClose(c, a * 2)
      self.assertAllClose(d, b * 4)

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testCollectiveReduce(self):
    fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
              in_axes=(['a', 'b', ...], {0: 'c'}),
              out_axes=(['b', ...], {0: 'c'}),
              axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
    ashape = (16, 8, 5)
    a = jnp.arange(np.prod(ashape)).reshape(ashape)
    bshape = (2, 7)
    b = jnp.arange(np.prod(bshape)).reshape(bshape)
    c, d = fm(a, b)
    self.assertAllClose(c, (a * 2).sum(0))
    self.assertAllClose(d, b * 4)

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testCollectivePermute2D(self):
    perm = np.array([3, 1, 2, 0])
    x = jnp.arange(4).reshape((2, 2))
    result = xmap(lambda x: lax.pshuffle(x, ('i', 'j'), perm),
                  in_axes=['i', 'j', ...],
                  out_axes=['i', 'j', ...],
                  axis_resources={'i': 'x', 'j': 'y'})(x).reshape((-1,))
    self.assertAllClose(result, perm)

  def testCollectivePermute1D(self):
    perm = np.array([3, 1, 2, 0])
    x = jnp.arange(4)
    result = xmap(lambda x: lax.pshuffle(x, 'i', perm),
                  in_axes=['i', ...],
                  out_axes=['i', ...])(x)
    self.assertAllClose(result, perm)

  def testCollectiveAllGather(self):
    x = jnp.arange(4)
    result = xmap(lambda x: lax.all_gather(x, 'i') + lax.axis_index('i'),
                  in_axes=['i', ...], out_axes=['i', ...])(x)
    self.assertAllClose(result, x[jnp.newaxis] + x[jnp.newaxis].T)

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testOneLogicalTwoMeshAxesBasic(self):
    def f(v):
      return lax.psum(v * 2, 'a'), v * 4
    fm = xmap(f, in_axes=['a', ...], out_axes=({}, {1: 'a'}),
              axis_resources={'a': ('x', 'y')})
    vshape = (4, 5)
    v = jnp.arange(np.prod(vshape)).reshape(vshape)
    ans, ans2 = fm(v)
    self.assertAllClose(ans, (v * 2).sum(0))
    self.assertAllClose(ans2, v.T * 4)

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testOneLogicalTwoMeshAxesSharding(self):
    def f(v):
      return v * 4
    fxy = xmap(f, in_axes=['a', ...], out_axes={1: 'a'},
               axis_resources={'a': ('x', 'y')})
    fyx = xmap(f, in_axes=['a', ...], out_axes={1: 'a'},
               axis_resources={'a': ('y', 'x')})
    vshape = (4, 5)
    v = jnp.arange(np.prod(vshape)).reshape(vshape)
    zxy = fxy(v)
    self.assertEqual(
        zxy.sharding_spec,
        pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                          (pxla.ShardedAxis(0), pxla.ShardedAxis(1))))
    zyx = fyx(v)
    self.assertEqual(
        zyx.sharding_spec,
        pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                          (pxla.ShardedAxis(1), pxla.ShardedAxis(0))))

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testSkipFirstMeshDim(self):
    def run(axis_resources):
      return xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
                  axis_resources=axis_resources)(jnp.ones((4,)))
    self.assertAllClose(run({'i': 'x'}), run({'i': 'y'}))

  def testCaching(self):
    def f(x):
      assert python_should_be_executing
      return x * 2
    devices = np.array(jax.local_devices()[:2])
    if devices.size < 2:
      raise SkipTest("Test requires 2 devices")
    x = np.arange(8).reshape((2, 2, 2))
    with maps.Mesh(devices, ('x',)):
      python_should_be_executing = True
      xmap(f, in_axes=['a', ...], out_axes=['a', ...],
           axis_resources={'a': 'x'})(x)
      python_should_be_executing = False
      xmap(f, in_axes=['a', ...], out_axes=['a', ...],
           axis_resources={'a': 'x'})(x)
    with maps.Mesh(devices, ('x',)):
      python_should_be_executing = False
      xmap(f, in_axes=['a', ...], out_axes=['a', ...],
           axis_resources={'a': 'x'})(x)

  @parameterized.named_parameters(
    {"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
    for name, mesh, axis_resources in (
      ('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))),
      ('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))),
    ))
  @jtu.with_mesh_from_kwargs
  def testNestedMesh(self, mesh, axis_resources):
    @partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
             axis_resources=dict([axis_resources[0]]))
    def f(x):
      y = x * 2
      @partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}),
               axis_resources=dict([axis_resources[1]]))
      def h(y):
        # Multiply by a constant array to better exercise the partial_eval rule
        return jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b'))
      return h(y)

    xshape = (4, 2, 5)
    x = jnp.arange(np.prod(xshape)).reshape(xshape)
    y = f(x)
    self.assertAllClose(y, ((jnp.sin(x * 2) * np.arange(xshape[-1])[None, None]).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
    self.assertEqual(y[0].sharding_spec.sharding,
                     (pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
    self.assertEqual(y[0].sharding_spec.mesh_mapping,
                     (pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2))
    if config.experimental_xmap_spmd_lowering:
      hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
      # Make sure that there are non-partial sharding specs in the HLO
      self.assertRegex(hlo, r"sharding={devices=\[[0-9,]+\][0-9,]+}")

  @jtu.with_and_without_mesh
  def testMultipleCalls(self, mesh, axis_resources):
    def f(x, y):
      assert x.shape == y.shape == (3, 5)
      return jnp.tensordot(x, y, axes=([1], [1]))

    f_mapped = xmap(f,
                    in_axes=(['i', ...], ['j', ...]),
                    out_axes=['i', 'j', ...],
                    axis_resources=dict(axis_resources))
    x = jnp.arange(30).reshape(2, 3, 5)
    expected = jnp.einsum('imk,jnk->ijmn', x, x)
    for i in range(10):
      self.assertAllClose(f_mapped(x, x), expected)

  @jtu.with_and_without_mesh
  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  def testBufferDonation(self, mesh, axis_resources):
    shard = lambda x: x
    if axis_resources:
      shard = xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...],
                   axis_resources=dict(axis_resources))
    f = xmap(lambda x, y: x + y * 4,
             in_axes=['i', ...], out_axes=['i', ...],
             axis_resources=dict(axis_resources),
             donate_argnums=0)
    # The multiplications below disable some optimizations that prevent reuse
    x = shard(jnp.zeros((2, 5)) * 4)
    y = shard(jnp.ones((2, 5)) * 2)
    f(x, y)
    self.assertNotDeleted(y)
    self.assertDeleted(x)

  @jtu.skip_on_devices("cpu")  # In/out aliasing not supported on CPU.
  @jtu.with_mesh([('x', 2)])
  @jtu.ignore_warning(category=UserWarning,  # SPMD test generates warning.
                      message="Some donated buffers were not usable*")
  def testBufferDonationNamedShape(self):
    axis_resources = {'i': 'x'}
    # named in_aval, unnamed out_aval
    f = xmap(lambda _: jnp.ones((2, 5)),
             in_axes=['i', ...], out_axes=[...],
             axis_resources=axis_resources,
             donate_argnums=0)
    shard = xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...],
                 axis_resources=dict(axis_resources))
    x = shard(jnp.zeros((4, 5)))
    f(x)
    self.assertDeleted(x)

  def testControlFlow(self):
    x = jnp.arange(5)
    xmap(lambda x: lax.fori_loop(0, 10, lambda _, x: lax.psum(x, 'i'), x),
         in_axes=['i', ...], out_axes=['i', ...])(x)

  @jtu.with_and_without_mesh
  def testAxisSizes(self, mesh, axis_resources):
    result = xmap(lambda: lax.axis_index('i'),
                  in_axes=(), out_axes=['i', ...],
                  axis_sizes={'i': 6},
                  axis_resources=dict(axis_resources))()
    self.assertAllClose(result, jnp.arange(6, dtype=result.dtype))

  def testCollectiveOverNoName(self):
    result = xmap(lambda: lax.psum(jnp.array(2) ** 2, 'i'),
                  in_axes={}, out_axes={}, axis_sizes={'i': 4})()
    self.assertEqual(result, 16)

  def VmapOfXmapCases(s):
    xmap_in_axes = ([{}] +
                    [{i: 'x'} for i in range(3)] +
                    [{i: 'x', j: 'y'} for i in range(4) for j in range(4) if i != j])
    for xmap_dim_x, xmap_dim_y in s(product(xmap_in_axes, repeat=2)):
      xmap_axes = sorted(set(xmap_dim_x.values()) | set(xmap_dim_y.values()))
      num_axes = len(xmap_axes)
      if xmap_axes is None:
        continue
      xmap_out_axes = [dict(zip(dims, xmap_axes))
                       for dims in permutations(range(2 + num_axes), num_axes)]
      for xmap_dim_z in s(xmap_out_axes):
        for vmap_dim_x in s([*range(2 + len(xmap_dim_x)), None]):
          for vmap_dim_y in s([*range(2 + len(xmap_dim_y)), None]):
            if vmap_dim_x is None and vmap_dim_y is None:
              continue
            for vmap_dim_result in s(range(3)):
              for vmap_dim_z in s(range(2 + len(xmap_axes))):
                for vmap_as_xmap in s([False, True]):
                  yield {"testcase_name":
                             f"_xin={(sorted(xmap_dim_x.items()), sorted(xmap_dim_y.items()))}_"
                             f"xout={sorted(xmap_dim_z.items())}_vin={(vmap_dim_x, vmap_dim_y)}_"
                             f"vout={vmap_dim_z}_vresult={vmap_dim_result}_vmap_as_xmap={vmap_as_xmap}",
                         "xmap_in_axes": (xmap_dim_x, xmap_dim_y),
                         "xmap_out_axes": xmap_dim_z,
                         "vmap_in_axes": (vmap_dim_x, vmap_dim_y),
                         "vmap_out_axes": vmap_dim_z,
                         "vmap_result_axis": vmap_dim_result,
                         "vmap_as_xmap": vmap_as_xmap}

  @parameterized.named_parameters(jtu.named_cases_from_sampler(VmapOfXmapCases))
  def testNestedMap(self,
                    xmap_in_axes, xmap_out_axes,
                    vmap_in_axes, vmap_out_axes, vmap_result_axis,
                    vmap_as_xmap):
    """Test various vmap(xmap) and xmap(xmap) combinations.

    The outer map always introduces a single dimension, the inner map introduces one or two.
    """
    (xin_x, xin_y) = xmap_in_axes
    (vin_x, vin_y) = vmap_in_axes
    vmap_size = 7
    xmap_sizes = {'x': 11, 'y': 13}

    xshape = [2, 3]
    yshape = [3, 5]
    zshape = [2, 5]
    xind = ['n', 'k']
    yind = ['k', 'm']
    zind = ['n', 'm']
    f = lambda x, y: ensure_bdim(jnp.einsum('nk,km->nm', x, y), 'v', vmap_result_axis)

    for pos, name in sorted(xin_x.items()):
      xshape.insert(pos, xmap_sizes[name])
      xind.insert(pos, name)
    for pos, name in sorted(xin_y.items()):
      yshape.insert(pos, xmap_sizes[name])
      yind.insert(pos, name)
    for pos, name in sorted(xmap_out_axes.items()):
      zshape.insert(pos, xmap_sizes[name])
      zind.insert(pos, name)

    if vin_x is not None:
      xshape.insert(vin_x, vmap_size)
      xind.insert(vin_x, 'v')
    if vin_y is not None:
      yshape.insert(vin_y, vmap_size)
      yind.insert(vin_y, 'v')
    zshape.insert(vmap_out_axes, vmap_size)
    zind.insert(vmap_out_axes, 'v')

    if vmap_as_xmap:
      do_vmap = partial(xmap,
                        in_axes=({vin_x: 'v'} if vin_x is not None else {},
                                 {vin_y: 'v'} if vin_y is not None else {}),
                        out_axes={vmap_out_axes: 'v'})
    else:
      do_vmap = partial(vmap, in_axes=vmap_in_axes, out_axes=vmap_out_axes, axis_name='v')

    fm = do_vmap(xmap(f, in_axes=xmap_in_axes, out_axes=xmap_out_axes))
    fref = partial(jnp.einsum, f"{''.join(xind)},{''.join(yind)}->{''.join(zind)}")

    rng = self.rng()
    x = rng.randn(*xshape)
    y = rng.randn(*yshape)
    self.assertAllClose(fm(x, y), fref(x, y))

  def testAutodiffBroadcast(self):
    f = xmap(lambda x, y: jnp.cos(lax.dot(x, jnp.sin(y),
                                          precision=lax.Precision.HIGHEST)),
             in_axes=(['i', ...], {}), out_axes=['i', ...])
    x = jnp.arange(12, dtype=jnp.float32).reshape((3, 4)) / 100
    y = jnp.arange(20, dtype=jnp.float32).reshape((4, 5)) / 100
    jtu.check_grads(f, (x, y), order=2, modes=['fwd'])
    jtu.check_grads(f, (x, y), order=1, modes=['rev'])
    with self.assertRaises(AssertionError):
      # Second-order reverse-mode differentiation seems to be broken,
      # likely due to the transpose of psum being defined incorrectly.
      jtu.check_grads(f, (x, y), order=2, modes=['rev'])

  def testAutodiffNoBroadcast(self):
    f = xmap(lambda x, y: jnp.cos(lax.dot(x, jnp.sin(y),
                                          precision=lax.Precision.HIGHEST)),
             in_axes=(['i', ...], [None, 'i']), out_axes=['i'])
    x = jnp.arange(12, dtype=jnp.float32).reshape((3, 4)) / 100
    y = jnp.arange(12, dtype=jnp.float32).reshape((4, 3)) / 100
    jtu.check_grads(f, (x, y), order=2)

  @jtu.with_and_without_mesh
  def testNamedShape(self, mesh, axis_resources):
    x = np.arange(4,)
    y = 2
    f = xmap(lambda x, y: (x + y, y * lax.axis_index('i')),
             in_axes=(['i', ...], {}),
             out_axes=(['i', ...], ['i', ...]),
             axis_resources=dict(axis_resources))
    z, w = f(x, y)
    self.assertEqual(z.aval.named_shape, {})
    self.assertEqual(w.aval.named_shape, {})

  @jtu.with_and_without_mesh
  def testBroadcast(self, mesh, axis_resources):
    x = jnp.asarray(2.0)
    f = xmap(lambda x: x, in_axes={}, out_axes=['i'],
             axis_sizes={'i': 4}, axis_resources=dict(axis_resources))
    self.assertAllClose(f(x), jnp.asarray([2.0, 2.0, 2.0, 2.0]))

  def testNestedBroadcast(self):
    x = jnp.asarray(2.0)
    f = xmap(lambda x: x, in_axes={}, out_axes=['i'], axis_sizes={'i': 4})
    g = xmap(f, in_axes={}, out_axes=['j', ...], axis_sizes={'j': 7})
    self.assertAllClose(g(x), jnp.tile(x.reshape((1, 1)), (7, 4)))

  @serial_loop('l', 4)
  def testLoopBasic(self):
    x = jnp.arange(16)
    y = xmap(lambda x: x + 4, in_axes=['i'], out_axes=['i'],
             axis_resources={'i': 'l'})(x)
    self.assertAllClose(y, x + 4)

  @jtu.with_mesh([('x', 2)])
  @serial_loop('l', 4)
  def testLoopWithMesh(self):
    x = jnp.arange(16)
    y = xmap(lambda x: x + 4, in_axes=['i'], out_axes=['i'],
             axis_resources={'i': ('x', 'l')})(x)
    self.assertAllClose(y, x + 4)

  def testLoopAnonBasic(self):
    x = jnp.arange(16)
    y = xmap(lambda x: x + 4, in_axes=['i'], out_axes=['i'],
             axis_resources={'i': SerialLoop(4)})(x)
    self.assertAllClose(y, x + 4)

  @jtu.with_mesh([('x', 2)])
  def testLoopAnonWithMesh(self):
    x = jnp.arange(16)
    y = xmap(lambda x: x + 4, in_axes=['i'], out_axes=['i'],
             axis_resources={'i': ('x', SerialLoop(4))})(x)
    self.assertAllClose(y, x + 4)

  def testLowerWithAbstractArgs(self):
    x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
    # Make sure this doesn't crash
    xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...]).lower(x)

  def testLowerCompile(self):
    f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
    x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
    f_exe = f.lower(x).compile()
    self.assertAllClose(f_exe(x), f(x))

  def testLowerCompileInTreeMismatch(self):
    f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
    x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
    f_exe = f.lower(x).compile()
    self.assertRaisesRegex(
        TypeError, "function compiled for .*, called with .*",
        lambda: f_exe([x]))

  def testLowerCompileArgTypeMismatch(self):
    f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
    x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
    x_f32 = x.astype(jnp.float32)
    x_i32 = x.astype(jnp.int32)
    f_exe = f.lower(x_f32).compile()
    self.assertRaisesRegex(
        TypeError,
        "Computation compiled for input types:\n.*float32.*\n"
        "called with:\n.*int32.*",
        lambda: f_exe(x_i32))

  def testLowerCompilerIR(self):
    f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
    x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
    f = f.lower(x)
    self.assertIsNotNone(f.compiler_ir())
    self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
    self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))

  def testLowerCompileCompilerIR(self):
    f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
    x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
    f = f.lower(x).compile()
    self.assertIsNotNone(f.compiler_ir())

  def testLowerCompileExecutable(self):
    f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
    x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
    f = f.lower(x).compile()
    self.assertIsNotNone(f.runtime_executable())

  def testNewCheckpoint(self):
    f = checkpoint(xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...]))
    self.assertAllClose(jax.grad(lambda x: f(x).sum())(jnp.arange(3.)), jnp.ones(3))

class XMapTestSPMD(SPMDTestMixin, XMapTest):
  """Re-executes all basic tests with the SPMD partitioner enabled"""

  skipped_tests = {
    "CollectivePermute2D"  # vmap of multidimensional permute not implemented yet
  }

  def setUp(self):
    for skipped_name in self.skipped_tests:
      if skipped_name in self._testMethodName:
        raise SkipTest
    super().setUp()

  @jtu.with_mesh([('x', 2), ('y', 2), ('z', 2)])
  def testNestedMeshSPMD(self):
    h = xmap(lambda y: (jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b', 'c'))),
             in_axes={0: 'c'}, out_axes=({1: 'c'}, {}),
             axis_resources={'c': 'z'})
    f = xmap(lambda x: h(x * 2),
             in_axes=[None, 'a', 'b', ...], out_axes=(['a', 'b', ...], {}),
             axis_resources={'a': 'x', 'b': 'y'})
    xshape = (8, 2, 4, 5)
    x = jnp.arange(np.prod(xshape)).reshape(xshape)
    hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
    match = re.search(r"sharding={devices=\[([0-9,]+)\][0-9,]+}", hlo)
    self.assertIsNot(match, None)
    tile_factors = [int(s) for s in match.group(1).split(',')]
    self.assertEqual(set(tile_factors), {1, 2})

  @jtu.with_mesh([('x', 2)])
  def testFixedSharding(self):
    # TODO(apaszke): Add support for extracting XLA computations generated by
    # xmap and make this less of a smoke test.
    try:
      config.update("experimental_xmap_ensure_fixed_sharding", True)
      f = xmap(lambda x: jnp.sin(2 * jnp.sum(jnp.cos(x) + 4, 'i')),
               in_axes=['i'], out_axes={}, axis_resources={'i': 'x'})
      x = jnp.arange(20, dtype=jnp.float32)
      f(x)
    finally:
      config.update("experimental_xmap_ensure_fixed_sharding", False)

class XMapTestManualSPMD(ManualSPMDTestMixin, XMapTestCase):
  @jtu.with_mesh([('x', 2)])
  def testBasic(self):
    f = lambda x: jnp.sin(jnp.cos(x) + x) * x
    fx = xmap(f, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'})
    x = jnp.arange(20, dtype=jnp.float32)
    self.assertAllClose(fx(x), f(x))

  @jtu.with_mesh([('x', 2)])
  def testReplicated(self):
    # TODO(apaszke): This seems to be failing if I try to have a replicated and a mapped argument?
    f = lambda x: jnp.sin(jnp.cos(x) + x) * x
    fx = xmap(f, in_axes=[...], out_axes=[...], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
    x = jnp.arange(20, dtype=jnp.float32)
    self.assertAllClose(fx(x), f(x))

  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testInPJit(self):
    f = xmap(lambda x: jnp.sin(x) + x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'})
    h = pjit(lambda x: f(x * x) + x, in_axis_resources=P('y'), out_axis_resources=None)
    x = jnp.arange(20, dtype=jnp.float32)
    self.assertAllClose(h(x), jnp.sin(x * x) + x * x + x)

  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testInPJitReplicated(self):
    f = xmap(lambda x: jnp.sin(x) + x, in_axes={}, out_axes={}, axis_sizes={'i': 4}, axis_resources={'i': 'x'})
    h = pjit(lambda x: f(x * x) + x, in_axis_resources=P('y'), out_axis_resources=None)
    x = jnp.arange(20, dtype=jnp.float32)
    self.assertAllClose(h(x), jnp.sin(x * x) + x * x + x)

  @jtu.with_mesh([('x', 2), ('y', 1)])
  def testNestedConstraint(self):
    # TODO(b/219691408): Using P('y') instead of P() causes an XLA crash!
    fimpl = lambda x: with_sharding_constraint(jnp.sin(x), P()) + x
    f = xmap(fimpl, in_axes=['i', ...], out_axes=['i', ...], axis_resources={'i': 'x'})
    h = pjit(lambda x: f(x * x) + x, in_axis_resources=P('y'), out_axis_resources=None)
    x = jnp.arange(20, dtype=jnp.float32).reshape(4, 5)
    self.assertAllClose(h(x), jnp.sin(x * x) + x * x + x)

class NamedNumPyTest(XMapTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": f"_{reduction.__name__}_axes={axes}_i={mapped_axis}",
     "reduction": reduction, "axes": axes, "mapped_axis": mapped_axis}
    for reduction in (jnp.sum, jnp.max, jnp.min, jnp.mean, jnp.var, jnp.std,
                      jscipy.special.logsumexp)
    for axes in (0, 'i', (1,), ('i',), (0, 1), (0, 'i'), ('i', 0))
    for mapped_axis in range(3)))
  def testReductions(self, reduction, axes, mapped_axis):
    axes_t = axes if isinstance(axes, tuple) else (axes,)
    ref_red = partial(reduction,
                      axis=tuple(mapped_axis if a == 'i' else a + (a >= mapped_axis)
                                 for a in axes_t))
    mapped_axis_after_red = mapped_axis - sum(axis < mapped_axis if axis != 'i' else 0
                                              for axis in axes_t)
    xmap_red = xmap(lambda x: reduction(x, axes),
                    in_axes={mapped_axis: 'i'},
                    out_axes=({} if 'i' in axes_t else {mapped_axis_after_red: 'i'}))

    rng = self.rng()
    x = rng.randn(2, 5, 6)
    self.assertAllClose(ref_red(x), xmap_red(x))

class NamedRandomTest(XMapTestCase):

  @curry
  def parameterize_by_sampler(extra, f, subset):
    if extra is None:
      extra = [("", {})]
    else:
      extra = list(extra)
    subset_fn = jtu.cases_from_list if subset else lambda x: x
    return parameterized.named_parameters(subset_fn(
      {"testcase_name": name + extra_name, "distr_sample": sample, **extra_kwargs}
      for name, sample in [
        ("Uniform", jax.random.uniform),
        ("Normal", jax.random.normal),
        ("Bernoulli", partial(jax.random.bernoulli, p=0.5)),
        ("TruncatedNormal", partial(jax.random.truncated_normal, lower=-2, upper=2)),
      ]
      for extra_name, extra_kwargs in extra))(f)

  @parameterize_by_sampler(None, subset=False)
  def testSamplerSharding(self, distr_sample):
    def sample(shape, map_size):
      return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=shape),
                  in_axes=(), out_axes=[None, 'i', ...], axis_sizes={'i': map_size})()
    replicated = sample((3,), 4)
    self.assertTrue((replicated[:,[0]] == replicated).all())
    sharded = sample(NamedShape(3, i=4), 4)
    self.assertFalse((sharded[:,[0]] == sharded[:,1:]).all(1).any())
    error = "The shape of axis i was specified as 4, but it really is 5"
    with self.assertRaisesRegex(ValueError, error):
      sample(NamedShape(3, i=4), 5)

  @parameterize_by_sampler(
    ((f"_mesh={mesh}_resources={sorted(axis_resources.items())}",
      {"axis_resources": tuple(axis_resources.items()), "mesh": tuple(mesh)})
     for axis_resources, mesh in schedules({'i': 4, 'j': 6})), subset=True)
  @jtu.with_mesh_from_kwargs
  def testSamplerResourceIndependence(self, distr_sample, axis_resources, mesh):
    def sample(axis_resources):
      return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=NamedShape(3, i=4, j=6)),
                  in_axes=(), out_axes=['i', 'j', ...], axis_sizes={'i': 4, 'j': 6},
                  axis_resources=axis_resources)()
    self.assertAllClose(sample({}), sample(dict(axis_resources)))

class NamedNNTest(XMapTestCase):

  def testOneHot(self):
    f = xmap(lambda x: jax.nn.one_hot([1, 2, 0], 3, axis='i'),
             in_axes=['i', ...], out_axes=['i', ...])
    expected = jnp.array([[0., 1., 0.],
                          [0., 0., 1.],
                          [1., 0., 0.]]).T
    self.assertAllClose(f(jnp.ones((3,))), expected)

  def testOneHotOutOfBound(self):
    f = xmap(lambda x: jax.nn.one_hot([-1, 3], 3, axis='i'),
             in_axes=['i', ...], out_axes=['i', ...])
    self.assertAllClose(f(jnp.ones((3,))), jnp.zeros((3, 2)))

  def testOneHotAxisSizeMismatch(self):
    f = xmap(lambda x: jax.nn.one_hot([-1, 3], 3, axis='i'),
             in_axes=['i', ...], out_axes=['i', ...])
    with self.assertRaisesRegex(ValueError, "to match the size of axis i, but 3 != 5"):
      f(jnp.ones((5,)))

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": f"_map_in={map_in}_map_out={map_out}_fan={fan}_distr={distr}",
     "map_in": map_in, "map_out": map_out, "fan": fan,
     "distr": distr}
    for map_in, map_out in [(True, False), (False, True), (True, True)]
    for fan in ['fan_in', 'fan_out', 'fan_avg']
    for distr in ['uniform', 'normal', 'truncated_normal']))
  def testVarianceScaling(self, map_in, map_out, fan, distr):
    shape = (80, 50, 7)
    fan_in, fan_out = jax._src.nn.initializers._compute_fans(
        NamedShape(*shape), 0, 1)
    key = jax.random.PRNGKey(1)
    base_scaling = partial(jax.nn.initializers.variance_scaling, 100, fan, distr)
    ref_sampler = lambda: base_scaling(in_axis=0, out_axis=1)(key, shape)
    if map_in and map_out:
      out_axes = ['i', 'o', ...]
      named_shape = NamedShape(shape[2], i=shape[0], o=shape[1])
      xmap_sampler = lambda: base_scaling(in_axis='i', out_axis='o')(key, named_shape)
    elif map_in:
      out_axes = ['i', ...]
      named_shape = NamedShape(shape[1], shape[2], i=shape[0])
      xmap_sampler = lambda: base_scaling(in_axis='i', out_axis=0)(key, named_shape)
    elif map_out:
      out_axes = [None, 'o', ...]
      named_shape = NamedShape(shape[0], shape[2], o=shape[1])
      xmap_sampler = lambda: base_scaling(in_axis=0, out_axis='o')(key, named_shape)
    mapped_sampler = xmap(xmap_sampler,
                          in_axes=(), out_axes=out_axes,
                          axis_sizes={'i': shape[0], 'o': shape[1]})
    self.assertAllClose(jnp.var(mapped_sampler()), jnp.var(ref_sampler()),
                        atol=1e-4, rtol=2e-2)

class XMapGDATest(XMapTestCase):

  @jtu.with_mesh([('x', 4), ('y', 2)])
  def test_basic(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x', 'y')
    input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return input_data[index]

    gda_obj = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    with jax._src.config.parallel_functions_output_gda(True):
      f = maps.xmap(
          lambda x: x,
          in_axes=({0: "a", 1: "b"}),
          out_axes=({0: "a", 1: "b"}),
          axis_resources={"a": "x", "b": "y"})

      out = f(gda_obj)
      self.assertIsInstance(out, global_device_array.GlobalDeviceArray)
      self.assertEqual(out.shape, (8, 2))
      self.assertEqual(out.local_shards[0].data.shape, (2, 1))
      self.assertDictEqual(out.mesh.shape, {'x': 4, 'y': 2})
      for s in out.local_shards:
        self.assertArraysEqual(s.data, input_data[s.index])

  @jtu.with_mesh([('x', 4), ('y', 2)])
  def test_xmap_gda_mixed_inputs(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x')
    input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return input_data[index]

    gda_obj = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    with jax._src.config.parallel_functions_output_gda(True):
      f = maps.xmap(
          lambda x, y: (x @ x.T, y @ y.T),
          in_axes=({0: "a"}, ["c", ...]),
          out_axes=({0: "a"}, ["c", ...]),
          axis_resources={"a": "x", "c": "x"})

      expected_matrix_mul = np.diagonal(input_data @ input_data.T)
      out1, out2 = f(gda_obj, input_data)

      self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
      self.assertEqual(out1.shape, (8,))
      self.assertEqual(out1.local_shards[0].data.shape, (2,))
      self.assertDictEqual(out1.mesh.shape, {'x': 4, 'y': 2})
      for s in out1.local_shards:
        self.assertArraysEqual(s.data, expected_matrix_mul[s.index])

      self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
      self.assertEqual(out2.shape, (8,))
      self.assertEqual(out2.local_shards[0].data.shape, (2,))
      self.assertDictEqual(out2.mesh.shape, {'x': 4, 'y': 2})
      for s in out2.local_shards:
        self.assertArraysEqual(s.data, expected_matrix_mul[s.index])

      for i, j in safe_zip(out1.local_shards, out2.local_shards):
        self.assertArraysEqual(i.data, j.data)

  @jtu.with_mesh([('x', 4), ('y', 2)])
  def test_xmap_gda_double_input(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return input_data[index]

    gda_obj1 = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, P('x'), cb)
    gda_obj2 = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, P('y'), cb)

    with jax._src.config.parallel_functions_output_gda(True):
      f = maps.xmap(
          lambda x, y: (x @ x.T, y @ y.T),
          in_axes=({0: "a"}, ["c", ...]),
          out_axes=({0: "a"}, ["c", ...]),
          axis_resources={"a": "x", "c": "y"})

      expected_matrix_mul = np.diagonal(input_data @ input_data.T)
      out1, out2 = f(gda_obj1, gda_obj2)

      self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
      self.assertEqual(out1.shape, (8,))
      self.assertEqual(out1.local_shards[0].data.shape, (2,))
      self.assertDictEqual(out1.mesh.shape, {'x': 4, 'y': 2})
      for s in out1.local_shards:
        self.assertArraysEqual(s.data, expected_matrix_mul[s.index])

      self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
      self.assertEqual(out2.shape, (8,))
      self.assertEqual(out2.local_shards[0].data.shape, (4,))
      self.assertDictEqual(out2.mesh.shape, {'x': 4, 'y': 2})
      for s in out2.local_shards:
        self.assertArraysEqual(s.data, expected_matrix_mul[s.index])

  @jtu.with_mesh([('x', 4), ('y', 2)])
  def test_xmap_gda_sharding_mismatch(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x', 'y')
    input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return input_data[index]

    gda_obj = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    with jax._src.config.parallel_functions_output_gda(True):
      f = maps.xmap(
          lambda x: x @ x.T,
          in_axes=({0: "a"}),
          out_axes=({0: "a"}),
          axis_resources={"a": "x"})
      with self.assertRaisesRegex(
          ValueError,
          ('Got an input GDA to xmap with different partitioning than '
           'specified in xmap. The partitioning must match.')):
        f(gda_obj)

class NewPrimitiveTest(XMapTestCase):

  def testGatherPositional(self):
    x = jnp.arange(27).reshape((9, 3))
    idx = jnp.array([1, 2, 1, 0]).reshape((2, 2))
    self.assertAllClose(pgather(x, idx, 0), x[idx.ravel()].reshape((2, 2, 3)))

    x_explode = x.reshape((3, 3, 3))
    self.assertAllClose(pgather(x, idx, 0), pgather(x_explode, idx, (0, 1)))

  @jtu.with_and_without_mesh
  def testGather(self, mesh, axis_resources):
    if axis_resources and not config.experimental_xmap_spmd_lowering:
      raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
    x = jnp.arange(12, dtype=np.float32).reshape((4, 3))
    y = jnp.arange(35).reshape((5, 7)) % 3
    f = xmap(lambda src, idx: pgather(src, idx, 'j'),
             in_axes=(['i', 'j'], ['k', 'm']),
             out_axes=['i', 'k', 'm'],
             axis_resources=dict(axis_resources))
    f_ref = lambda x, y: x[:, y.reshape((-1,))].reshape((4, 5, 7))
    self.assertAllClose(f(x, y), f_ref(x, y))


class NewPrimitiveTestSPMD(SPMDTestMixin, NewPrimitiveTest):
  pass

AxisIndices = Tuple[int, ...]
MatchedAxisIndices = Tuple[AxisIndices, AxisIndices]
AxisNames = Tuple[str, ...]

class PdotTestSpec:
  # The axis indices stored by a PdotTestSpec are all positional indices
  # *before* taking mapping into account.
  map_cont: MatchedAxisIndices
  pos_cont: MatchedAxisIndices
  map_batch: MatchedAxisIndices
  pos_batch: MatchedAxisIndices
  all_names: AxisNames
  contract_names: AxisNames
  batch_names: AxisNames

  def __init__(self, map_cont, pos_cont, map_batch, pos_batch):
    self.map_cont = map_cont
    self.pos_cont = pos_cont
    self.map_batch = map_batch
    self.pos_batch = pos_batch

    names = gen_axis_names()
    self.contract_names = [next(names) for _ in range(len(map_cont[0]))]
    self.batch_names = [next(names) for _ in range(len(map_batch[0]))]
    self.all_names = self.contract_names + self.batch_names

  @property
  def dot_general_dim_nums(self):
    lhs_contract = (*self.map_cont[0], *self.pos_cont[0])
    rhs_contract = (*self.map_cont[1], *self.pos_cont[1])
    lhs_batch = (*self.map_batch[0], *self.pos_batch[0])
    rhs_batch = (*self.map_batch[1], *self.pos_batch[1])
    return (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)

  @property
  def pos_contract_after_mapping(self):
    lhs = [i - sum(j < i for j in self._lhs_mapped) for i in self.pos_cont[0]]
    rhs = [i - sum(j < i for j in self._rhs_mapped) for i in self.pos_cont[1]]
    return (lhs, rhs)

  @property
  def pos_batch_after_mapping(self):
    lhs = [i - sum(j < i for j in self._lhs_mapped) for i in self.pos_batch[0]]
    rhs = [i - sum(j < i for j in self._rhs_mapped) for i in self.pos_batch[1]]
    return (lhs, rhs)

  @property
  def _lhs_mapped(self):
    return {*self.map_cont[0], *self.map_batch[0]}

  @property
  def _rhs_mapped(self):
    return {*self.map_cont[1], *self.map_batch[1]}

  @property
  def lhs_in_axes(self):
    axis_indices = [*self.map_cont[0], *self.map_batch[0]]
    return dict(zip(axis_indices, self.all_names))

  @property
  def rhs_in_axes(self):
    axis_indices = [*self.map_cont[1], *self.map_batch[1]]
    return dict(zip(axis_indices, self.all_names))
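# Added commentary: a hand-worked example of the spec class above (hypothetical
# values, not used by the generated tests). For an lhs of shape (3, 8) and an
# rhs of shape (8, 5), contracting lhs dim 1 against rhs dim 0 through a named
# (mapped) axis and with no batch or positional contractions:
#
#   spec = PdotTestSpec(map_cont=((1,), (0,)), pos_cont=((), ()),
#                       map_batch=((), ()), pos_batch=((), ()))
#   spec.contract_names              == ['i']   # first name from gen_axis_names
#   spec.lhs_in_axes                 == {1: 'i'}
#   spec.rhs_in_axes                 == {0: 'i'}
#   spec.dot_general_dim_nums        == (((1,), (0,)), ((), ()))
#   spec.pos_contract_after_mapping  == ([], [])
#
# which is essentially the configuration exercised by testPdotBasic below.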
def all_pdot_specs(lhs_shape, rhs_shape):
  for matching in axis_matchings(lhs_shape, rhs_shape):
    for lists in partitions(matching, 4):
      yield PdotTestSpec(*map(unzip2, lists))

def axis_matchings(lhs_shape, rhs_shape):
  def helper(start, exc1, exc2):
    yield ()
    for i in range(start, len(lhs_shape)):
      d1 = lhs_shape[i]
      if i not in exc1:
        for j, d2 in enumerate(rhs_shape):
          if d1 == d2 and j not in exc2:
            for matches in helper(i + 1, exc1 | {i}, exc2 | {j}):
              yield ((i, j), *matches)
  return helper(0, set(), set())

def gen_axis_names():
  names = 'ijkl'
  for n in it.count(1):
    for chars in product(names, repeat=n):
      yield ''.join(chars)
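# Added commentary (hand-computed from the generators above, not upstream):
#   list(axis_matchings((2,), (2,))) == [(), ((0, 0),)]
#   the first values of gen_axis_names(): 'i', 'j', 'k', 'l', 'ii', 'ij', ...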
def schedules_from_pdot_spec(
    spec: PdotTestSpec, lhs_shape: Tuple[int], rhs_shape: Tuple[int]
    ) -> Generator[Tuple[AxisResources, jtu.MeshSpec], None, None]:
  logical_sizes = {
      name: shape[ax]
      for shape, in_axes in [(lhs_shape, spec.lhs_in_axes),
                             (rhs_shape, spec.rhs_in_axes)]
      for ax, name in in_axes.items()}
  yield from schedules(logical_sizes)

class PDotTests(XMapTestCase):

  @jtu.with_mesh([('r1', 2)])
  def testPdotBasic(self):
    def f(x, y):
      return lax.pdot(x, y, 'i')

    f_mapped = xmap(f,
                    in_axes=({1: 'i'}, {0: 'i'}),
                    out_axes={},
                    axis_resources={'i': 'r1'})

    rng = self.rng()
    x = rng.randn(3, 8)
    y = rng.randn(8, 5)

    z = f_mapped(x, y)

    self.assertAllClose(z, jnp.dot(x, y))

  @jtu.with_mesh([('r1', 2)])
  def testPdotBatching(self):
    def f(x, y):
      return lax.pdot(x, y, 'i')

    rng = self.rng()
    x = rng.randn(2, 3, 8)
    y = rng.randn(2, 8, 5)

    f_mapped = xmap(f,
                    in_axes=({0: 'j', 2: 'i'}, {0: 'j', 1: 'i'}),
                    out_axes=['j', ...],
                    axis_resources={'i': 'r1'})

    z = f_mapped(x, y)

    self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))

  @jtu.with_mesh([('r1', 2)])
  def testPdotBatchingShardUncontractedDim(self):
    def f(x, y):
      return lax.pdot(x, y, 'i')

    rng = self.rng()
    x = rng.randn(2, 3, 8)
    y = rng.randn(2, 8, 5)

    f_mapped = xmap(f,
                    in_axes=({0: 'j', 2: 'i'}, {0: 'j', 1: 'i'}),
                    out_axes=['j', ...],
                    axis_resources={'j': 'r1'})

    z = f_mapped(x, y)

    self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))

  @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
      "testcase_name": f"_{next(test_counter)}",
      "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "pdot_spec": pdot_spec,
      "axis_resources": axis_resources, "mesh_data": mesh_data
    } for test_counter in [it.count()]
      for lhs_shape, rhs_shape in s(product([(2,), (2, 4, 2, 1)], repeat=2))
      for pdot_spec in s(all_pdot_specs(lhs_shape, rhs_shape))
      for axis_resources, mesh_data in s(schedules_from_pdot_spec(
          pdot_spec, lhs_shape, rhs_shape))
  )))
  def testPdotSystematic(self, lhs_shape, rhs_shape, pdot_spec, axis_resources,
                         mesh_data):
    rng = jtu.rand_default(self.rng())
    lhs = rng(lhs_shape, np.float32)
    rhs = rng(rhs_shape, np.float32)

    def pdot_fun(x, y):
      # print(f'pdot(x:{x.aval.str_short()}, y:{y.aval.str_short()},\n'
      #       f'     axis_name={contract_names},\n'
      #       f'     pos_contract={spec.pos_contract_after_mapping}\n'
      #       f'     pos_batch={spec.pos_batch_after_mapping})')
      return jax.lax.pdot(x, y, axis_name=pdot_spec.contract_names,
                          pos_batch=pdot_spec.pos_batch_after_mapping,
                          pos_contract=pdot_spec.pos_contract_after_mapping)

    fun = xmap(pdot_fun, in_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes),
               out_axes=[*pdot_spec.batch_names, ...],
               axis_resources=axis_resources)

    with jtu.with_mesh(mesh_data):
      result = fun(lhs, rhs)

    expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
    self.assertAllClose(result, expected, check_dtypes=False,
                        atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
      "testcase_name": f"_{next(test_counter)}",
      "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "pdot_spec": pdot_spec,
      "axis_resources": axis_resources, "mesh_data": mesh_data
    } for test_counter in [it.count()]
      for lhs_shape, rhs_shape in s(product([(2,), (2, 4, 2, 1)], repeat=2))
      for pdot_spec in s(all_pdot_specs(lhs_shape, rhs_shape))
      for axis_resources, mesh_data in s(schedules_from_pdot_spec(
          pdot_spec, lhs_shape, rhs_shape))
  )))
  def testPdotVJPSystematic(self, lhs_shape, rhs_shape, pdot_spec,
                            axis_resources, mesh_data):
    rng = jtu.rand_default(self.rng())
    lhs = rng(lhs_shape, np.float32)
    rhs = rng(rhs_shape, np.float32)

    expected_out, ref_vjp = jax.vjp(
        lambda x, y: lax.dot_general(x, y, pdot_spec.dot_general_dim_nums),
        lhs, rhs)
    out_bar = rng(expected_out.shape, np.float32)
    expected_lhs, expected_rhs = ref_vjp(out_bar)

    def pdot_fun(x, y, out_bar):
      pdot = partial(jax.lax.pdot,
                     axis_name=pdot_spec.contract_names,
                     pos_batch=pdot_spec.pos_batch_after_mapping,
                     pos_contract=pdot_spec.pos_contract_after_mapping)
      _, pdot_vjp = jax.vjp(pdot, x, y)
      return pdot_vjp(out_bar)

    fun = xmap(pdot_fun,
               in_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes,
                        [*pdot_spec.batch_names, ...]),
               out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes),
               axis_resources=axis_resources)

    with jtu.with_mesh(mesh_data):
      lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)

    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
    self.assertAllClose(lhs_bar, expected_lhs, check_dtypes=False,
                        atol=tol, rtol=tol)
    self.assertAllClose(rhs_bar, expected_rhs, check_dtypes=False,
                        atol=tol, rtol=tol)

  def test_xeinsum_vector_dot(self):
    rng = self.rng()
    x = rng.randn(3)
    y = rng.randn(3)
    out = xmap(partial(jnp.einsum, '{i},{i}->'),
               in_axes=(['i'], ['i']), out_axes=[])(x, y)
    expected = np.einsum('i,i->', x, y)
    self.assertAllClose(out, expected, check_dtypes=False)

  def test_xeinsum_outer_product(self):
    rng = self.rng()
    x = rng.randn(3)
    y = rng.randn(3)
    out = xmap(partial(jnp.einsum, '{i},{j}->{i,j}'),
               in_axes=(['i'], ['j']), out_axes=['i', 'j'])(x, y)
    expected = np.einsum('i,j->ij', x, y)
    self.assertAllClose(out, expected, check_dtypes=True)

  def test_xeinsum_matmul(self):
    rng = self.rng()
    x = rng.randn(3, 4)
    y = rng.randn(4, 5)

    def check(spec):
      out = xmap(partial(jnp.einsum, spec),
                 in_axes=(['i', 'j'], ['j', 'k']),
                 out_axes=['i', 'k'])(x, y)
      expected = np.einsum('ij,jk->ik', x, y)
      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
      self.assertAllClose(out, expected, check_dtypes=True,
                          atol=tol, rtol=tol)
    check('{i,j},{j,k}->{i,k}')
    check('{i,j},{k,j}->{k,i}')  # order of named axes in the spec doesn't matter!
    check('{j},{k,j}->{k}')
    check('{i,j},{j}->{i}')
    check('{j},{j}->{}')

  def test_xeinsum_no_named_axes_vector_dot(self):
    rng = self.rng()
    x = rng.randn(3)
    y = rng.randn(3)
    out = jnp.einsum('i,i->', x, y, _use_xeinsum=True)
    expected = np.einsum('i,i->', x, y)
    self.assertAllClose(out, expected, check_dtypes=False)

  def test_xeinsum_no_named_axes_batch_vector_dot(self):
    rng = self.rng()
    x = rng.randn(3, 2)
    y = rng.randn(3, 2)
    out = jnp.einsum('ij,ij->i', x, y, _use_xeinsum=True)
    expected = np.einsum('ij,ij->i', x, y)
    self.assertAllClose(out, expected, check_dtypes=True)

  def test_xeinsum_no_named_axes_batch_matmul(self):
    rng = np.random.RandomState(0)
    x = rng.randn(3, 5, 4)
    y = rng.randn(3, 4, 2)
    out = jnp.einsum('bij,bjk->bik', x, y, _use_xeinsum=True)
    expected = np.einsum('bij,bjk->bik', x, y)
    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
    self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)

  def test_xeinsum_no_named_axes_reduce_sum(self):
    rng = self.rng()
    x = rng.randn(3)
    y = rng.randn()
    out = jnp.einsum('i,->', x, y, _use_xeinsum=True)
    expected = np.einsum('i,->', x, y)
    self.assertAllClose(out, expected, check_dtypes=True)

  def test_xeinsum_no_named_axes_reduce_and_contract(self):
    rng = np.random.RandomState(0)
    x = rng.randn(3, 5, 4)
    y = rng.randn(2, 4, 2)
    out = jnp.einsum('bij,cjk->ik', x, y, _use_xeinsum=True)
    expected = np.einsum('bij,cjk->ik', x, y)
    tol = 1e-1 if jtu.device_under_test() == "tpu" else None
    self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)

  def test_xeinsum_named_axes_reduce(self):
    rng = np.random.RandomState(0)
    x = rng.randn(3, 4)
    y = rng.randn(5,)

    def check(spec):
      out = xmap(partial(jnp.einsum, spec),
                 in_axes=(['i', 'j'], ['k']),
                 out_axes=['i', 'k'])(x, y)
      expected = np.einsum('ij,k->ik', x, y)
      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
      self.assertAllClose(out, expected, check_dtypes=True,
                          atol=tol, rtol=tol)
    check('{i,j},{k}->{i,k}')

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def test_xeinsum_named_axes_reduce_with_mesh(self):
    rng = np.random.RandomState(0)
    x = rng.randn(6, 4)
    y = rng.randn(8,)

    def check(spec):
      out = xmap(partial(jnp.einsum, spec),
                 in_axes=(['i', 'j'], ['k']),
                 out_axes=['i', 'k'],
                 axis_resources={'i': 'x', 'k': 'y'})(x, y)
      expected = np.einsum('ij,k->ik', x, y)
      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
      self.assertAllClose(out, expected, check_dtypes=True,
                          atol=tol, rtol=tol)

    check('{i,j},{k}->{i,k}')
    check('{i,j},{k}->{k,i}')  # order of named axes in the spec doesn't matter!
    check('{j,i},{k}->{i,k}')
    check('{j,i},{k}->{k,i}')

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def test_xeinsum_named_axes_batch_matmul_with_mesh(self):
    rng = np.random.RandomState(0)
    x = rng.randn(8, 3, 4)
    y = rng.randn(8, 4, 5)

    def check(spec):
      out = xmap(partial(jnp.einsum, spec),
                 in_axes=(['b', 'i', 'j'], ['b', 'j', 'k']),
                 out_axes=['b', 'i', 'k'],
                 axis_resources={'b': 'x', 'j': 'y'})(x, y)
      expected = np.einsum('bij,bjk->bik', x, y)
      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
      self.assertAllClose(out, expected, check_dtypes=True,
                          atol=tol, rtol=tol)

    check('{b,i,j},{b,j,k}->{b,i,k}')
    check('{j,i,b},{j,b,k}->{i,b,k}')  # order of named axes in the spec doesn't matter!

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def test_xeinsum_named_axes_unary_reduce_with_mesh(self):
    rng = np.random.RandomState(0)
    x = rng.randn(8, 6, 4)

    def check(spec):
      out = xmap(partial(jnp.einsum, spec),
                 in_axes=['b', 'i', 'j'],
                 out_axes=['b'],
                 axis_resources={'b': 'x', 'i': 'y'})(x)
      expected = np.einsum('bij->b', x)
      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
      self.assertAllClose(out, expected, check_dtypes=True,
                          atol=tol, rtol=tol)

    check('{b,i,j}->{b}')
    check('{b,j,i}->{b}')  # order of named axes in the spec doesn't matter!
    check('{i,j,b}->{b}')

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def test_xeinsum_mixed_axes_unary_reduce_with_mesh(self):
    rng = np.random.RandomState(0)
    x = rng.randn(8, 6, 4, 5)

    def check(spec):
      out = xmap(partial(jnp.einsum, spec),
                 in_axes=['b', 'i', ...],
                 out_axes=['b', ...],
                 axis_resources={'b': 'x', 'i': 'y'})(x)
      expected = np.einsum('bijk->bk', x)
      tol = 1e-1 if jtu.device_under_test() == "tpu" else None
      self.assertAllClose(out, expected, check_dtypes=True,
                          atol=tol, rtol=tol)

    check('jk{i,b}->k{b}')



class XMapErrorTest(jtu.JaxTestCase):

  @jtu.with_mesh([('x', 2)])
  def testRepeatedAxisResource(self):
    def f(v):
      return v * 4
    with self.assertRaisesRegex(ValueError, r"distinct resources.*specified \('x', 'x'\) for axis a"):
      xmap(f, in_axes=['a', ...], out_axes=['a', ...],
           axis_resources={'a': ('x', 'x')})

  @jtu.with_mesh([('y', 2)])
  def testUndefinedAxisResource(self):
    error = re.escape(
        r"In-scope resources are insufficient to execute the xmapped function. "
        r"The missing resources are: {'x'}")
    with self.assertRaisesRegex(ValueError, error):
      xmap(lambda x: x, in_axes=['a', ...], out_axes=['a', ...],
           axis_resources={'a': 'x'})(jnp.zeros((4,)))

  @jtu.with_mesh([('x', 2)])
  def testNestedDifferentResources(self):
    @partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
    def f(x):
      with maps.Mesh(np.empty((), dtype=np.object_), ()):
        @partial(xmap, in_axes={0: 'b'}, out_axes={0: 'b'})
        def h(x):
          return x
        return h(x)
    xshape = (2, 5, 6)
    x = jnp.arange(np.prod(xshape)).reshape(xshape)
    with self.assertRaisesRegex(RuntimeError,
                                "Changing the physical mesh is not allowed.*"):
      f(x)

  def testEmptyArgumentTrees(self):
    with self.assertRaisesRegex(ValueError, "Failed to infer size of axes: i."):
      xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...])({})

  @jtu.with_mesh([('x', 2), ('y', 2)])
  def testAxesNotDivisibleByResources(self):
    with self.assertRaisesRegex(ValueError, r"Size of axis i \(5\) is not divisible.*"
                                            r"\(\('x', 'y'\), 4 in total\)"):
      xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...],
           axis_sizes={'i': 5}, axis_resources={'i': ('x', 'y')})({})

  def testInconsistentAxisSizes(self):
    x5 = jnp.arange(5)
    x6 = jnp.arange(6)
    error = (r"The size of axis i was previously inferred to be 5, but found an "
             r"argument of shape \(6,\) with in_axes specification \['i', ...\]. "
             r"Shape mismatch occurs in dimension 0: 6 != 5")
    with self.assertRaisesRegex(ValueError, error):
      xmap(lambda x, y: x, in_axes=(['i', ...], ['i', ...]), out_axes=['i', ...])(x5, x6)
    with self.assertRaisesRegex(ValueError, error):
      xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...], axis_sizes={'i': 5})(x6)

  def testInAxesRankError(self):
    error = (r"One of xmap arguments has an in_axes specification of \['i', 'j', ...\], "
             r"which implies that it has at least 2 dimensions, but the argument has rank 1")
    with self.assertRaisesRegex(ValueError, error):
      xmap(lambda x: x, in_axes=['i', 'j', ...], out_axes=['j', 'i', ...])(jnp.ones((5,)))

  def testOutAxesRankError(self):
    error = (r"One of xmap outputs has an out_axes specification of {1: 'i'}, "
             r"which requires the result of the xmapped function to have at least "
             r"1 positional dimensions, but it only has 0")
    with self.assertRaisesRegex(ValueError, error):
      xmap(lambda x: x, in_axes=['i', ...], out_axes={1: 'i'})(jnp.ones((5,)))

  def testNegativeAxes(self):
    with self.assertRaisesRegex(ValueError, "xmap doesn't support negative axes in in_axes"):
      xmap(lambda x: x, in_axes={-1: 'i'}, out_axes={0: 'i'})(jnp.ones((5,)))
    with self.assertRaisesRegex(ValueError, "xmap doesn't support negative axes in out_axes"):
      xmap(lambda x: x, in_axes={0: 'i'}, out_axes={-1: 'i'})(jnp.ones((5,)))

  def testDictOutAxes(self):
    # see issue #6410
    out = xmap(lambda x: x, in_axes=[...], out_axes={"a": [...]})({"a": 1})
    self.assertEqual(out, {"a": 1})

  def testListAxesRankAssertion(self):
    error = (r"xmap argument has an in_axes specification of \['i', None\], which "
             r"asserts that it should be of rank 2, but the argument has rank 1 "
             r"\(and shape \(5,\)\)")
    with self.assertRaisesRegex(ValueError, error):
      xmap(lambda x: x, in_axes=['i', None], out_axes=['i', None])(jnp.ones((5,)))
    error = (r"xmap output has an out_axes specification of \['i', None\], which "
             r"asserts that it should be of rank 2, but the output has rank 3 "
             r"\(and shape \(5, 2, 2\)\)")
    with self.assertRaisesRegex(ValueError, error):
      xmap(lambda x: x.reshape((2, 2)),
           in_axes=['i', None], out_axes=['i', None])(jnp.ones((5, 4)))

  def testReturnExtraMappedAxes(self):
    fm = xmap(lambda x, y: x + y,
              in_axes=(['a', ...], ['b', ...]), out_axes=['a', ...])
    x = np.arange(12).reshape((4, 3))
    y = np.arange(6).reshape((2, 3))
    error = (r"One of xmap results has an out_axes specification of \['a', ...\], but "
             r"is actually mapped along more axes defined by this xmap call: b")
    with self.assertRaisesRegex(TypeError, error):
      fm(x, y)
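
  # The ResourceConflict* tests below verify that two named axes assigned to
  # the same resource may not appear together in any value's named_shape,
  # whether the conflict arises in an input, an intermediate, or an output.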

  @jtu.with_mesh([('x', 2)])
  def testResourceConflictArgs(self):
    fm = xmap(lambda x: lax.psum(x, ('a', 'b')),
              in_axes=['a', 'b'], out_axes=[],
              axis_resources={'a': 'x', 'b': 'x'})
    x = np.arange(16).reshape(4, 4)
    error = (r"Axes `a` and `b` are both mapped to the resource `x`, but they "
             r"coincide in the named_shape of an input to an xmapped function "
             r"<lambda>")
    with self.assertRaisesRegex(JAXTypeError, error):
      fm(x)

  @jtu.with_mesh([('x', 2)])
  def testResourceConflictInner(self):
    fm = xmap(lambda x, y: x + y,
              in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
              axis_resources={'a': 'x', 'b': 'x'})
    x = np.arange(12).reshape(4, 3)
    y = np.arange(6).reshape(2, 3)
    error = (r"Axes `a` and `b` are both mapped to the resource `x`, but they "
             r"coincide in the named_shape.*primitive add created at")
    with self.assertRaisesRegex(JAXTypeError, error):
      fm(x, y)

  @jtu.with_mesh([('x', 2)])
  def testResourceConflictOut(self):
    fm = xmap(lambda x, y: x,
              in_axes=(['a', ...], ['b', ...]), out_axes=['a', 'b', ...],
              axis_resources={'a': 'x', 'b': 'x'})
    x = np.arange(12).reshape(4, 3)
    y = np.arange(6).reshape(2, 3)
    error = (r"One of xmapped function \(<lambda>\) outputs is broadcast along axis "
             r"`b` which is assigned to resources `x`, but the output is already "
             r"partitioned along `x`, because its named shape contains `a`")
    with self.assertRaisesRegex(JAXTypeError, error):
      fm(x, y)

  @jtu.with_mesh([('x', 2)])
  def testResourceConflictNestArgs(self):
    f = xmap(lambda x: x, in_axes=['i'], out_axes=['i'], axis_resources={'i': 'x'})
    h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
    x = np.arange(16).reshape((4, 4))
    error = (r"Axes `i` and `j` are both mapped to the resource `x`, but they "
             r"coincide in the named_shape of an input to an xmapped function "
             r"<lambda> \(xmap called at .*\)")
    with self.assertRaisesRegex(JAXTypeError, error):
      h(x)

  @jtu.with_mesh([('x', 2)])
  def testResourceConflictNestInner(self):
    f = xmap(lambda x: lax.axis_index('i') + x,
             in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
    h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
    x = np.arange(4)
    error = (r"Axes `i` and `j` are both mapped to the resource `x`, but they "
             r"coincide in the named_shape of a value returned from a primitive "
             r"add created at .*")
    with self.assertRaisesRegex(JAXTypeError, error):
      h(x)

  @jtu.with_mesh([('x', 2)])
  def testResourceConflictNestOut(self):
    f = xmap(lambda x: x,
             in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
    h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
    x = np.arange(4)
    error = (r"One of xmapped function \(<lambda>\) outputs is broadcast along "
             r"axis `i` which is assigned to resources `x`, but the output is "
             r"already partitioned along `x`, because its named shape contains `j`")
    with self.assertRaisesRegex(JAXTypeError, error):
      h(x)
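
  # The two tests below use a serial_loop resource instead of a mesh axis:
  # axes mapped to a loop resource are evaluated sequentially, and (as the
  # second test checks) they cannot be referenced by collectives.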

  @serial_loop('l', 2)
  def testResourceConflictArgsLoop(self):
    fm = xmap(lambda x: x,
              in_axes=['a', 'b'], out_axes=['a', 'b'],
              axis_resources={'a': 'l', 'b': 'l'})
    x = np.arange(16).reshape(4, 4)
    error = (r"Axes `a` and `b` are both mapped to the resource `l`, but they "
             r"coincide in the named_shape of an input to an xmapped function "
             r"<lambda>")
    with self.assertRaisesRegex(JAXTypeError, error):
      fm(x)

  @serial_loop('l', 2)
  def testLoopCollectives(self):
    fm = xmap(lambda x: lax.psum(x, 'i'),
              in_axes=['i'], out_axes=[],
              axis_resources={'i': 'l'})
    x = np.arange(16)
    error = (r"Named axes with loop resources assigned to them cannot be "
             r"referenced inside the xmapped computation \(e.g. in "
             r"collectives\), but `i` violates that rule")
    with self.assertRaisesRegex(RuntimeError, error):
      fm(x)

  def testAxesMismatch(self):
    x = jnp.ones((4,))
    p = [['x'], ['x'], ['x']]
    xmap(lambda x: x, (p,), p)([x, x, x])  # OK
    xmap(lambda x: x, [p], p)([x, x, x])  # OK
    error = re.escape(
        r"xmap in_axes specification must be a tree prefix of the "
        r"corresponding value, got specification (['x'], ['x'], ['x']) for value "
        r"tree PyTreeDef((*, *)). Note that xmap in_axes that are "
        r"non-trivial pytrees should always be wrapped in a tuple representing "
        r"the argument list.")
    with self.assertRaisesRegex(ValueError, error):
      xmap(lambda x, y: x, p, p)(x, x)  # Error, but make sure we hint at tupling
    # TODO(apaszke): Disable implicit list casts and enable this
    # error = re.escape(
    #     r"xmap in_axes specification must be a tree prefix of the "
    #     r"corresponding value, got specification (['x'], ['x'], ['x']) for value "
    #     r"tree PyTreeDef(([*, *, *],)). Note that xmap in_axes that "
    #     r"are non-trivial pytrees should always be wrapped in a tuple representing "
    #     r"the argument list. In particular, you're passing in a single argument "
    #     r"which means that xmap in_axes might need to be wrapped in a "
    #     r"singleton tuple.")
    # with self.assertRaisesRegex(ValueError, error):
    #   xmap(lambda x: x, p, p)([x, x, x])  # Error, but make sure we hint at singleton tuple
    error = re.escape(
        r"xmap out_axes specification must be a tree prefix of the "
        r"corresponding value, got specification ([['x'], ['x'], ['x']], ['x']) for "
        r"value tree PyTreeDef([*, *, *]).")
    with self.assertRaisesRegex(ValueError, error):
      xmap(lambda x: x, (p,), (p, ['x']))([x, x, x])  # Error, we raise a generic tree mismatch message
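
# The tests below cover autodiff with named axes: jax.vjp/jax.grad with
# reduce_axes=('batch',) should sum gradients over the named 'batch' axis,
# matching the gradients of the equivalent positionally-batched function.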


class NamedAutodiffTests(jtu.JaxTestCase):

  def testVjpReduceAxes(self):
    def f(w, x):
      return jnp.sin(jnp.dot(x, w))

    def vjp_f(w, x, gy):
      _, pullback = jax.vjp(f, w, x)
      return pullback(gy)

    def vjp_f_reduced(w, x, gy):
      _, pullback = jax.vjp(f, w, x, reduce_axes=('batch',))
      return pullback(gy)

    w = np.arange(12, dtype=np.float32).reshape(3, 4)
    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    gy = np.arange(8, dtype=np.float32).reshape(2, 4)

    # per-example
    error = (r"One of xmap results has an out_axes specification of {}, but is "
             r"actually mapped along more axes defined by this xmap call: "
             r"batch")
    with self.assertRaisesRegex(TypeError, error):
      xmap(vjp_f,
           in_axes=({}, {0: 'batch'}, {0: 'batch'}),
           out_axes=({}, {0: 'batch'}))(w, x, gy)
    out = xmap(vjp_f,
               in_axes=({}, {0: 'batch'}, {0: 'batch'}),
               out_axes=({0: 'batch'}, {0: 'batch'}))(w, x, gy)
    expected = vmap(vjp_f, in_axes=(None, 0, 0), out_axes=(0, 0))(w, x, gy)
    self.assertAllClose(out, expected, check_dtypes=True)

    # reduced
    out = xmap(vjp_f_reduced,
               in_axes=({}, {0: 'batch'}, {0: 'batch'}),
               out_axes=({}, {0: 'batch'}))(w, x, gy)
    # the reduced VJP is also the VJP when using a positional batch axis
    expected = vjp_f(w, x, gy)
    self.assertAllClose(out, expected, check_dtypes=True)
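
  # psum_idrev below is a psum whose custom VJP is the identity, so the
  # reverse pass does not re-apply the sum over the named axis.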

  def testVjpReduceAxesCollective(self):

    # lax.psum has the wrong transpose, so test with a corrected version for now
    @functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
    def psum_idrev(x, axis_name: Optional[AxisNames] = None):
      if axis_name is None:
        return x
      return jax.lax.psum(x, axis_name)

    def psum_idrev_fwd(x, axis_name):
      return psum_idrev(x, axis_name), None

    def psum_idrev_bwd(axis_name, res, g):
      del axis_name, res
      return (g,)

    psum_idrev.defvjp(psum_idrev_fwd, psum_idrev_bwd)

    def f_named(w, x):
      return psum_idrev(jnp.sin(jnp.dot(x, w)).sum(), 'batch')

    def f_positional(w, x):
      return jnp.sin(jnp.dot(x, w)).sum()

    w = np.arange(12, dtype=np.float32).reshape(3, 4)
    x = np.arange(6, dtype=np.float32).reshape(2, 3)

    # forward
    out = xmap(f_named, in_axes=({}, {0: 'batch'}), out_axes={})(w, x)
    expected = f_positional(w, x)
    self.assertAllClose(out, expected, check_dtypes=True)

    # gradient
    out = xmap(jax.grad(f_named, (0, 1), reduce_axes=('batch',)),
               in_axes=({}, {0: 'batch'}),
               out_axes=({}, {0: 'batch'}))(w, x)
    expected = jax.grad(f_positional, (0, 1))(w, x)
    self.assertAllClose(out, expected, check_dtypes=True)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())