rocm_jax/tests/xmap_test.py
2021-03-16 15:11:36 -04:00

1127 lines
44 KiB
Python

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
from contextlib import contextmanager
import functools
import itertools as it
import os
import unittest
from itertools import product, permutations
from typing import (Tuple, List, NamedTuple, Dict, Generator, Sequence, Set,
Any, Hashable, Iterable, Iterator, Union)
from unittest import SkipTest, skip, skipIf
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
from functools import partial
import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax import test_util as jtu
from jax import vmap
from jax import lax
from jax import core
from jax.core import NamedShape
from jax.experimental.maps import Mesh, mesh, xmap
from jax.lib import xla_bridge
from jax._src.util import curry, unzip2, split_list, prod
from jax._src.lax.lax import DotDimensionNumbers
from jax._src.lax.parallel import pgather
from jax.interpreters import batching, pxla
from jax.config import config
config.parse_flags_with_absl()
ignore_xmap_warning = functools.partial(
jtu.ignore_warning, message=".*is an experimental.*")
# TODO(mattjj): de-duplicate setUpModule and tearDownModule with pmap_test.py
# Run all tests with 8 CPU devices.
def setUpModule():
global prev_xla_flags
prev_xla_flags = os.getenv("XLA_FLAGS")
flags_str = prev_xla_flags or ""
# Don't override user-specified device count, or other XLA flags.
if "xla_force_host_platform_device_count" not in flags_str:
os.environ["XLA_FLAGS"] = (flags_str +
" --xla_force_host_platform_device_count=8")
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
if prev_xla_flags is None:
del os.environ["XLA_FLAGS"]
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
# -------------------- Mesh parametrization helpers --------------------
MeshSpec = List[Tuple[str, int]]
@contextmanager
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
"""Test utility for setting up meshes given mesh data from `schedules`."""
# This is similar to the `with_mesh` function above, but isn't a decorator.
axis_names, shape = unzip2(named_shape)
size = prod(shape)
local_devices = list(jax.local_devices())
if len(local_devices) < size:
raise SkipTest(f"Test requires {size} local devices")
mesh_devices = np.array(local_devices[:size]).reshape(shape)
with mesh(mesh_devices, axis_names):
yield
def with_mesh_from_kwargs(f):
return lambda *args, **kwargs: with_mesh(kwargs['mesh'])(f)(*args, **kwargs)
def with_and_without_mesh(f):
return parameterized.named_parameters(
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
for name, mesh, axis_resources in (
('', (), ()),
('Mesh', (('x', 2),), (('i', 'x'),))
))(with_mesh_from_kwargs(f))
# -------------------- Itertools helpers --------------------
def partitions(s, k):
for indices in product(range(k), repeat=len(s)):
outs = [[] for _ in range(k)]
for i, elt in zip(indices, s):
outs[i].append(elt)
yield outs
def powerset(s):
s = list(s)
return it.chain.from_iterable(it.combinations(s, r) for r in range(len(s)+1))
# -------------------- vmap test helpers --------------------
ensure_bdim_p = core.Primitive('ensure_bdim')
ensure_bdim_p.def_abstract_eval(lambda x, **kwargs: core.raise_to_shaped(x))
def _ensure_bdim_batcher(frame, vals_in, dims_in, axis_name, bdim):
v, = vals_in
d, = dims_in
assert d is not batching.not_mapped
return jnp.moveaxis(v, d, bdim), bdim
batching.collective_rules[ensure_bdim_p] = _ensure_bdim_batcher
batching.primitive_batchers[ensure_bdim_p] = lambda v, d: (v[0], d[0])
core.axis_substitution_rules[ensure_bdim_p] = partial(jax._src.lax.parallel._subst_all_names_in_param,
'axis_name')
def ensure_bdim(x, axis_name, bdim):
return ensure_bdim_p.bind(x, axis_name=(axis_name,), bdim=bdim)
# -------------------- Axis resources generation --------------------
AxisResources = Dict[str, Union[str, Tuple[str, ...]]]
def schedules(sizes: Dict[str, int]
) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
"""Test utility generating xmap parallel schedules from logical names & sizes.
Args:
sizes: dict mapping logical axis name to its corresponding size.
Returns:
A generator producing finitely many values, where each value is a pair in
which the first element is a value suitable for xmap's axis_resources
argument and the second element is a list of pairs with the first element
representing a generated physical mesh axis name and the second element
representing a corresponding generated mesh axis size. The generated mesh
names/sizes can be used to define a physical mesh in tests.
This function doesn't generate schedules which map distinct logical axis names
to the same parallel resource name. It only generates parallel resources; the
rest are implicitly left for vectorization. Parallel resource names are
generated by prepending an 'r', 'r1', or 'r2' to the corresponding logical
name.
Examples:
>>> for sched in schedules({'i': 2, 'j': 4}):
... print(sched)
({}, [])
({'i': 'ri'}, [('ri', 1)])
({'i': 'ri'}, [('ri', 2)])
({'i': ('r1i', 'r2i')}, [('r1i', 1), ('r2i', 1)])
({'i': ('r1i', 'r2i')}, [('r1i', 1), ('r2i', 2)])
({'i': ('r1i', 'r2i')}, [('r1i', 2), ('r2i', 1)])
({'j': 'rj'}, [('rj', 1)])
({'j': 'rj'}, [('rj', 2)])
({'j': 'rj'}, [('rj', 4)])
({'j': ('r1j', 'r2j')}, [('r1j', 1), ('r2j', 1)])
({'j': ('r1j', 'r2j')}, [('r1j', 1), ('r2j', 2)])
({'j': ('r1j', 'r2j')}, [('r1j', 1), ('r2j', 4)])
({'j': ('r1j', 'r2j')}, [('r1j', 2), ('r2j', 1)])
({'j': ('r1j', 'r2j')}, [('r1j', 2), ('r2j', 2)])
({'j': ('r1j', 'r2j')}, [('r1j', 4), ('r2j', 1)])
({'i': 'ri', 'j': 'rj'}, [('ri', 1), ('rj', 1)])
({'i': 'ri', 'j': 'rj'}, [('ri', 1), ('rj', 2)])
({'i': 'ri', 'j': 'rj'}, [('ri', 1), ('rj', 4)])
({'i': 'ri', 'j': 'rj'}, [('ri', 2), ('rj', 1)])
({'i': 'ri', 'j': 'rj'}, [('ri', 2), ('rj', 2)])
({'i': 'ri', 'j': 'rj'}, [('ri', 2), ('rj', 4)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 1), ('r2j', 1)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 1), ('r2j', 2)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 1), ('r2j', 4)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 2), ('r2j', 1)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 2), ('r2j', 2)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 1), ('r1j', 4), ('r2j', 1)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 1), ('r2j', 1)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 1), ('r2j', 2)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 1), ('r2j', 4)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 2), ('r2j', 1)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 2), ('r2j', 2)])
({'i': 'ri', 'j': ('r1j', 'r2j')}, [('ri', 2), ('r1j', 4), ('r2j', 1)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 1), ('r1i', 1), ('r2i', 1)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 1), ('r1i', 1), ('r2i', 2)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 1), ('r1i', 2), ('r2i', 1)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 2), ('r1i', 1), ('r2i', 1)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 2), ('r1i', 1), ('r2i', 2)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 2), ('r1i', 2), ('r2i', 1)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 4), ('r1i', 1), ('r2i', 1)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 4), ('r1i', 1), ('r2i', 2)])
({'j': 'rj', 'i': ('r1i', 'r2i')}, [('rj', 4), ('r1i', 2), ('r2i', 1)])
"""
def divisors(n: int) -> List[int]:
return [m for m in range(1, n + 1) if not n % m]
def divisors2(n: int) -> Iterator[Tuple[int, int]]:
for k1 in divisors(n):
for k2 in divisors(n // k1):
yield (k1, k2)
# choose a subset of logical axis names to map to parallel resources
for names in powerset(sizes):
# partition that set of logical axis names into two subsets: one subset to
# map to one parallel resource axis and a second subset to map to two
# parallel resource axes.
for names1, names2 in partitions(names, 2):
# to avoid generating too many complex cases, we skip generating cases
# where more than one logical axis name is to be mapped to two parallel
# resource axes. comment out this line to generate more complex tests.
if len(names2) > 1: continue
# make up parallel resource axis names for each logical axis
axis_resources1 = ((name, 'r' + name) for name in names1)
axis_resources2 = ((name, ('r1' + name, 'r2' + name)) for name in names2)
axis_resources = dict(it.chain(axis_resources1, axis_resources2))
# make up sizes for each resource axis, where the size must divide the
# corresponding logical axis
for mesh_sizes1 in product(*(divisors(sizes[n]) for n in names1)):
for mesh_sizes2 in product(*(divisors2(sizes[n]) for n in names2)):
mesh_data1 = (('r' + name, size) for name, size in zip(names1, mesh_sizes1))
mesh_data2 = (pair for name, (size1, size2) in zip(names2, mesh_sizes2)
for pair in [('r1' + name, size1), ('r2' + name, size2)])
mesh_data = list(it.chain(mesh_data1, mesh_data2))
yield axis_resources, mesh_data
class XMapTestCase(jtu.BufferDonationTestCase):
def setUp(self):
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
super().setUp()
# A mixin that enables SPMD lowering tests
class SPMDTestMixin:
def setUp(self):
if jtu.device_under_test() != "tpu":
raise SkipTest
super().setUp()
jax.experimental.maps.make_xmap_callable.cache_clear()
self.old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
def tearDown(self):
jax.experimental.maps.make_xmap_callable.cache_clear()
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = self.old_lowering_flag
class XMapTest(XMapTestCase):
@ignore_xmap_warning()
def testBasic(self):
local_devices = list(jax.local_devices())
if len(local_devices) < 4:
raise SkipTest("Test requires at least 4 local devices")
def f(a, b):
return a * 2, b * 4
devices = np.array(local_devices[:4]).reshape((2, 2))
with mesh(devices, ('x', 'y')):
fm = xmap(f,
in_axes=[{0: 'a', 1: 'b'}, ['c', ...]],
out_axes=[{0: 'a', 1: 'b'}, ['c', ...]],
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
ashape = (16, 8, 5)
a = jnp.arange(np.prod(ashape)).reshape(ashape)
bshape = (2, 7)
b = jnp.arange(np.prod(bshape)).reshape(bshape)
c, d = fm(a, b)
self.assertAllClose(c, a * 2)
self.assertAllClose(d, b * 4)
@ignore_xmap_warning()
@with_mesh([('x', 2), ('y', 2)])
def testCollectiveReduce(self):
fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
in_axes=[['a', 'b', ...], {0: 'c'}],
out_axes=[['b', ...], {0: 'c'}],
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
ashape = (16, 8, 5)
a = jnp.arange(np.prod(ashape)).reshape(ashape)
bshape = (2, 7)
b = jnp.arange(np.prod(bshape)).reshape(bshape)
c, d = fm(a, b)
self.assertAllClose(c, (a * 2).sum(0))
self.assertAllClose(d, b * 4)
@ignore_xmap_warning()
@with_mesh([('x', 2), ('y', 2)])
def testCollectivePermute2D(self):
perm = np.array([3, 1, 2, 0])
x = jnp.arange(4).reshape((2, 2))
result = xmap(lambda x: lax.pshuffle(x, ('i', 'j'), perm),
in_axes=['i', 'j', ...],
out_axes=['i', 'j', ...],
axis_resources={'i': 'x', 'j': 'y'})(x).reshape((-1,))
self.assertAllClose(result, perm)
@ignore_xmap_warning()
def testCollectivePermute1D(self):
perm = np.array([3, 1, 2, 0])
x = jnp.arange(4)
result = xmap(lambda x: lax.pshuffle(x, 'i', perm),
in_axes=['i', ...],
out_axes=['i', ...])(x)
self.assertAllClose(result, perm)
@ignore_xmap_warning()
@with_mesh([('x', 2), ('y', 2)])
def testOneLogicalTwoMeshAxesBasic(self):
def f(v):
return lax.psum(v * 2, 'a'), v * 4
fm = xmap(f, in_axes=['a', ...], out_axes=[{}, {1: 'a'}],
axis_resources={'a': ('x', 'y')})
vshape = (4, 5)
v = jnp.arange(np.prod(vshape)).reshape(vshape)
ans, ans2 = fm(v)
self.assertAllClose(ans, (v * 2).sum(0))
self.assertAllClose(ans2, v.T * 4)
@ignore_xmap_warning()
@with_mesh([('x', 2), ('y', 2)])
def testOneLogicalTwoMeshAxesSharding(self):
def f(v):
return v * 4
fxy = xmap(f, in_axes=['a', ...], out_axes={1: 'a'},
axis_resources={'a': ('x', 'y')})
fyx = xmap(f, in_axes=['a', ...], out_axes={1: 'a'},
axis_resources={'a': ('y', 'x')})
vshape = (4, 5)
v = jnp.arange(np.prod(vshape)).reshape(vshape)
zxy = fxy(v)
self.assertEqual(
zxy.sharding_spec,
pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
(pxla.ShardedAxis(0), pxla.ShardedAxis(1))))
zyx = fyx(v)
self.assertEqual(
zyx.sharding_spec,
pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
(pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
@ignore_xmap_warning()
@with_mesh([('x', 2), ('y', 2)])
def testSkipFirstMeshDim(self):
def run(axis_resources):
return xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
axis_resources=axis_resources)(jnp.ones((4,)))
self.assertAllClose(run({'i': 'x'}), run({'i': 'y'}))
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testCompilationCache(self):
def f(x):
assert python_should_be_executing
return x * 2
fm = xmap(f,
in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': 'x'})
x = np.arange(8).reshape((2, 2, 2))
python_should_be_executing = True
fm(x)
python_should_be_executing = False
fm(x)
@parameterized.named_parameters(
{"testcase_name": name, "mesh": mesh, "axis_resources": axis_resources}
for name, mesh, axis_resources in (
('OneToOne', (('x', 2), ('y', 2)), (('a', 'y'), ('b', 'x'))),
('Multiple', (('x', 2), ('y', 2), ('z', 2)), (('a', 'y'), ('b', ('x', 'z')))),
))
@with_mesh_from_kwargs
@ignore_xmap_warning()
def testNestedMesh(self, mesh, axis_resources):
@partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
axis_resources=dict([axis_resources[0]]))
def f(x):
y = x * 2
@partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}),
axis_resources=dict([axis_resources[1]]))
def h(y):
return jnp.sin(y), lax.psum(y, ('a', 'b'))
return h(y)
xshape = (4, 2, 5)
x = jnp.arange(np.prod(xshape)).reshape(xshape)
y = f(x)
self.assertAllClose(y, (jnp.sin(x * 2).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
self.assertEqual(y[0].sharding_spec.sharding,
(pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
self.assertEqual(y[0].sharding_spec.mesh_mapping,
(pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2))
@with_and_without_mesh
@ignore_xmap_warning()
def testMultipleCalls(self, mesh, axis_resources):
def f(x, y):
assert x.shape == y.shape == (3, 5)
return jnp.tensordot(x, y, axes=([1], [1]))
f_mapped = xmap(f,
in_axes=(['i', ...], ['j', ...]),
out_axes=['i', 'j', ...],
axis_resources=dict(axis_resources))
x = jnp.arange(30).reshape(2, 3, 5)
expected = jnp.einsum('imk,jnk->ijmn', x, x)
for i in range(10):
self.assertAllClose(f_mapped(x, x), expected)
@with_and_without_mesh
@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
@ignore_xmap_warning()
def testBufferDonation(self, mesh, axis_resources):
shard = lambda x: x
if axis_resources:
shard = xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...],
axis_resources=dict(axis_resources))
f = xmap(lambda x, y: x + y * 4,
in_axes=['i', ...], out_axes=['i', ...],
axis_resources=dict(axis_resources),
donate_argnums=0)
# The multiplications below disable some optimizations that prevent reuse
x = shard(jnp.zeros((2, 5)) * 4)
y = shard(jnp.ones((2, 5)) * 2)
f(x, y)
self.assertNotDeleted(y)
self.assertDeleted(x)
@ignore_xmap_warning()
def testControlFlow(self):
x = jnp.arange(5)
xmap(lambda x: lax.fori_loop(0, 10, lambda _, x: lax.psum(x, 'i'), x),
in_axes=['i', ...], out_axes=['i', ...])(x)
@with_and_without_mesh
@ignore_xmap_warning()
def testAxisSizes(self, mesh, axis_resources):
result = xmap(lambda: lax.axis_index('i'),
in_axes=(), out_axes=['i', ...],
axis_sizes={'i': 6},
axis_resources=dict(axis_resources))()
self.assertAllClose(result, jnp.arange(6, dtype=result.dtype))
def VmapOfXmapCases():
xmap_in_axes = ([{}] +
[{i: 'x'} for i in range(3)] +
[{i: 'x', j: 'y'} for i in range(4) for j in range(4) if i != j])
for xmap_dim_x, xmap_dim_y in product(xmap_in_axes, repeat=2):
xmap_axes = sorted(set(xmap_dim_x.values()) | set(xmap_dim_y.values()))
num_axes = len(xmap_axes)
if xmap_axes is None:
continue
xmap_out_axes = [dict(zip(dims, xmap_axes))
for dims in permutations(range(2 + num_axes), num_axes)]
for xmap_dim_z in xmap_out_axes:
for vmap_dim_x in [*range(2 + len(xmap_dim_x)), None]:
for vmap_dim_y in [*range(2 + len(xmap_dim_y)), None]:
if vmap_dim_x is None and vmap_dim_y is None:
continue
for vmap_dim_result in range(3):
for vmap_dim_z in range(2 + len(xmap_axes)):
for vmap_as_xmap in [False, True]:
yield {"testcase_name":
f"_xin={(sorted(xmap_dim_x.items()), sorted(xmap_dim_y.items()))}_"
f"xout={sorted(xmap_dim_z.items())}_vin={(vmap_dim_x, vmap_dim_y)}_"
f"vout={vmap_dim_z}_vresult={vmap_dim_result}_vmap_as_xmap={vmap_as_xmap}",
"xmap_in_axes": (xmap_dim_x, xmap_dim_y),
"xmap_out_axes": xmap_dim_z,
"vmap_in_axes": (vmap_dim_x, vmap_dim_y),
"vmap_out_axes": vmap_dim_z,
"vmap_result_axis": vmap_dim_result,
"vmap_as_xmap": vmap_as_xmap}
@parameterized.named_parameters(jtu.cases_from_list(VmapOfXmapCases()))
@ignore_xmap_warning()
def testNestedMap(self,
xmap_in_axes, xmap_out_axes,
vmap_in_axes, vmap_out_axes, vmap_result_axis,
vmap_as_xmap):
"""Test various vmap(xmap) and xmap(xmap) combinations.
The outer map always introduces a single dimension, the inner map introduces one or two.
"""
(xin_x, xin_y) = xmap_in_axes
(vin_x, vin_y) = vmap_in_axes
vmap_size = 7
xmap_sizes = {'x': 11, 'y': 13}
xshape = [2, 3]
yshape = [3, 5]
zshape = [2, 5]
xind = ['n', 'k']
yind = ['k', 'm']
zind = ['n', 'm']
f = lambda x, y: ensure_bdim(jnp.einsum('nk,km->nm', x, y), 'v', vmap_result_axis)
for pos, name in sorted(xin_x.items()):
xshape.insert(pos, xmap_sizes[name])
xind.insert(pos, name)
for pos, name in sorted(xin_y.items()):
yshape.insert(pos, xmap_sizes[name])
yind.insert(pos, name)
for pos, name in sorted(xmap_out_axes.items()):
zshape.insert(pos, xmap_sizes[name])
zind.insert(pos, name)
if vin_x is not None:
xshape.insert(vin_x, vmap_size)
xind.insert(vin_x, 'v')
if vin_y is not None:
yshape.insert(vin_y, vmap_size)
yind.insert(vin_y, 'v')
zshape.insert(vmap_out_axes, vmap_size)
zind.insert(vmap_out_axes, 'v')
if vmap_as_xmap:
do_vmap = partial(xmap,
in_axes=({vin_x: 'v'} if vin_x is not None else {},
{vin_y: 'v'} if vin_y is not None else {}),
out_axes={vmap_out_axes: 'v'})
else:
do_vmap = partial(vmap, in_axes=vmap_in_axes, out_axes=vmap_out_axes, axis_name='v')
fm = do_vmap(xmap(f, in_axes=xmap_in_axes, out_axes=xmap_out_axes))
fref = partial(jnp.einsum, f"{''.join(xind)},{''.join(yind)}->{''.join(zind)}")
rng = np.random.RandomState(0)
x = rng.randn(*xshape)
y = rng.randn(*yshape)
self.assertAllClose(fm(x, y), fref(x, y))
class XMapTestSPMD(SPMDTestMixin, XMapTest):
"""Re-executes all basic tests with the SPMD partitioner enabled"""
skipped_tests = {
"NestedMesh", # Nesting xmap calls is not supported in the SPMD lowering yet
"NestedMap", # Same as above
"CollectivePermute2D" # vmap of multidimensional permute not implemented yet
}
def setUp(self):
for skipped_name in self.skipped_tests:
if skipped_name in self._testMethodName:
raise SkipTest
super().setUp()
class NamedNumPyTest(XMapTestCase):
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{reduction.__name__}_axes={axes}_i={mapped_axis}",
"reduction": reduction, "axes": axes, "mapped_axis": mapped_axis}
for reduction in (jnp.sum, jnp.max, jnp.min, jnp.mean, jnp.var, jnp.std,
jscipy.special.logsumexp)
for axes in (0, 'i', (1,), ('i',), (0, 1), (0, 'i'), ('i', 0))
for mapped_axis in range(3)))
@ignore_xmap_warning()
def testReductions(self, reduction, axes, mapped_axis):
axes_t = axes if isinstance(axes, tuple) else (axes,)
reduces_i = 'i' in axes_t
ref_red = partial(reduction,
axis=tuple(mapped_axis if a == 'i' else a + (a >= mapped_axis)
for a in axes_t))
mapped_axis_after_red = mapped_axis - sum(axis < mapped_axis if axis != 'i' else 0
for axis in axes_t)
xmap_red = xmap(lambda x: reduction(x, axes),
in_axes={mapped_axis: 'i'},
out_axes=({} if 'i' in axes_t else {mapped_axis_after_red: 'i'}))
rng = np.random.RandomState(0)
x = rng.randn(2, 5, 6)
self.assertAllClose(ref_red(x), xmap_red(x))
class NamedRandomTest(XMapTestCase):
@curry
def parameterize_by_sampler(extra, f, subset):
if extra is None:
extra = [("", {})]
else:
extra = list(extra)
subset_fn = jtu.cases_from_list if subset else lambda x: x
return parameterized.named_parameters(subset_fn(
{"testcase_name": name + extra_name, "distr_sample": sample, **extra_kwargs}
for name, sample in [
("Uniform", jax.random.uniform),
("Normal", jax.random.normal),
("Bernoulli", partial(jax.random.bernoulli, p=0.5)),
("TruncatedNormal", partial(jax.random.truncated_normal, lower=-2, upper=2)),
]
for extra_name, extra_kwargs in extra))(f)
@parameterize_by_sampler(None, subset=False)
@ignore_xmap_warning()
def testSamplerSharding(self, distr_sample):
def sample(shape, map_size):
return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=shape),
in_axes=(), out_axes=[None, 'i', ...], axis_sizes={'i': map_size})()
replicated = sample((3,), 4)
self.assertTrue((replicated[:,[0]] == replicated).all())
sharded = sample(NamedShape(3, i=4), 4)
self.assertFalse((sharded[:,[0]] == sharded[:,1:]).all(1).any())
error = "The shape of axis i was specified as 4, but it really is 5"
with self.assertRaisesRegex(ValueError, error):
sample(NamedShape(3, i=4), 5)
@parameterize_by_sampler(
((f"_mesh={mesh}_resources={sorted(axis_resources.items())}",
{"axis_resources": tuple(axis_resources.items()), "mesh": tuple(mesh)})
for axis_resources, mesh in schedules({'i': 4, 'j': 6})), subset=True)
@with_mesh_from_kwargs
@ignore_xmap_warning()
def testSamplerResourceIndependence(self, distr_sample, axis_resources, mesh):
def sample(axis_resources):
return xmap(lambda: distr_sample(jax.random.PRNGKey(0), shape=NamedShape(3, i=4, j=6)),
in_axes=(), out_axes=['i', 'j', ...], axis_sizes={'i': 4, 'j': 6},
axis_resources=axis_resources)()
self.assertAllClose(sample({}), sample(dict(axis_resources)))
class NamedNNTest(XMapTestCase):
@ignore_xmap_warning()
def testOneHot(self):
f = xmap(lambda x: jax.nn.one_hot([1, 2, 0], 3, axis='i'),
in_axes=['i', ...], out_axes=['i', ...])
expected = jnp.array([[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.]]).T
self.assertAllClose(f(jnp.ones((3,))), expected)
@ignore_xmap_warning()
def testOneHotOutOfBound(self):
f = xmap(lambda x: jax.nn.one_hot([-1, 3], 3, axis='i'),
in_axes=['i', ...], out_axes=['i', ...])
self.assertAllClose(f(jnp.ones((3,))), jnp.zeros((3, 2)))
@ignore_xmap_warning()
def testOneHotAxisSizeMismatch(self):
f = xmap(lambda x: jax.nn.one_hot([-1, 3], 3, axis='i'),
in_axes=['i', ...], out_axes=['i', ...])
with self.assertRaisesRegex(ValueError, "to match the size of axis i, but 3 != 5"):
f(jnp.ones((5,)))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_map_in={map_in}_map_out={map_out}_fan={fan}_distr={distr}",
"map_in": map_in, "map_out": map_out, "fan": fan,
"distr": distr}
for map_in, map_out in [(True, False), (False, True), (True, True)]
for fan in ['fan_in', 'fan_out', 'fan_avg']
for distr in ['uniform', 'normal', 'truncated_normal']))
@ignore_xmap_warning()
def testVarianceScaling(self, map_in, map_out, fan, distr):
shape = (80, 50, 7)
fan_in, fan_out = jax._src.nn.initializers._compute_fans(
NamedShape(*shape), 0, 1)
key = jax.random.PRNGKey(0)
base_scaling = partial(jax.nn.initializers.variance_scaling, 100, fan, distr)
ref_sampler = lambda: base_scaling(in_axis=0, out_axis=1)(key, shape)
if map_in and map_out:
out_axes=['i', 'o', ...]
named_shape = NamedShape(shape[2], i=shape[0], o=shape[1])
xmap_sampler = lambda: base_scaling(in_axis='i', out_axis='o')(key, named_shape)
elif map_in:
out_axes = ['i', ...]
named_shape = NamedShape(shape[1], shape[2], i=shape[0])
xmap_sampler = lambda: base_scaling(in_axis='i', out_axis=0)(key, named_shape)
elif map_out:
out_axes = [None, 'o', ...]
named_shape = NamedShape(shape[0], shape[2], o=shape[1])
xmap_sampler = lambda: base_scaling(in_axis=0, out_axis='o')(key, named_shape)
mapped_sampler = xmap(xmap_sampler,
in_axes=(), out_axes=out_axes,
axis_sizes={'i': shape[0], 'o': shape[1]})
self.assertAllClose(jnp.var(mapped_sampler()), jnp.var(ref_sampler()),
atol=1e-4, rtol=2e-2)
class NewPrimitiveTest(XMapTestCase):
def setUp(self):
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
def testGatherPositional(self):
x = jnp.arange(27).reshape((9, 3))
idx = jnp.array([1, 2, 1, 0]).reshape((2, 2))
self.assertAllClose(pgather(x, idx, 0), x[idx.ravel()].reshape((2, 2, 3)))
x_explode = x.reshape((3, 3, 3))
self.assertAllClose(pgather(x, idx, 0), pgather(x_explode, idx, (0, 1)))
@with_and_without_mesh
@ignore_xmap_warning()
def testGather(self, mesh, axis_resources):
if axis_resources and not jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING:
raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
x = jnp.arange(12, dtype=np.float32).reshape((4, 3))
y = jnp.arange(35).reshape((5, 7)) % 3
f = xmap(lambda src, idx: pgather(src, idx, 'j'),
in_axes=(['i', 'j'], ['k', 'm']),
out_axes=['i', 'k', 'm'],
axis_resources=dict(axis_resources))
f_ref = lambda x, y: x[:, y.reshape((-1,))].reshape((4, 5, 7))
self.assertAllClose(f(x, y), f_ref(x, y))
class NewPrimitiveTestSPMD(SPMDTestMixin, NewPrimitiveTest):
pass
AxisIndices = Tuple[int, ...]
MatchedAxisIndices = Tuple[AxisIndices, AxisIndices]
AxisNames = Tuple[str, ...]
class PdotTestSpec:
# The axis indices stored by a PdotTestSpec are all positional indices
# *before* taking mapping into account.
map_cont: MatchedAxisIndices
pos_cont: MatchedAxisIndices
map_batch: MatchedAxisIndices
pos_batch: MatchedAxisIndices
all_names: AxisNames
contract_names: AxisNames
batch_names: AxisNames
def __init__(self, map_cont, pos_cont, map_batch, pos_batch):
self.map_cont = map_cont
self.pos_cont = pos_cont
self.map_batch = map_batch
self.pos_batch = pos_batch
names = gen_axis_names()
self.contract_names = [next(names) for _ in range(len(map_cont[0]))]
self.batch_names = [next(names) for _ in range(len(map_batch[0]))]
self.all_names = self.contract_names + self.batch_names
@property
def dot_general_dim_nums(self):
lhs_contract = (*self.map_cont[0], *self.pos_cont[0])
rhs_contract = (*self.map_cont[1], *self.pos_cont[1])
lhs_batch = (*self.map_batch[0], *self.pos_batch[0])
rhs_batch = (*self.map_batch[1], *self.pos_batch[1])
return (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
@property
def pos_contract_after_mapping(self):
lhs = [i - sum(j < i for j in self._lhs_mapped) for i in self.pos_cont[0]]
rhs = [i - sum(j < i for j in self._rhs_mapped) for i in self.pos_cont[1]]
return (lhs, rhs)
@property
def pos_batch_after_mapping(self):
lhs = [i - sum(j < i for j in self._lhs_mapped) for i in self.pos_batch[0]]
rhs = [i - sum(j < i for j in self._rhs_mapped) for i in self.pos_batch[1]]
return (lhs, rhs)
@property
def _lhs_mapped(self):
return {*self.map_cont[0], *self.map_batch[0]}
@property
def _rhs_mapped(self):
return {*self.map_cont[1], *self.map_batch[1]}
@property
def lhs_in_axes(self):
axis_indices = [*self.map_cont[0], *self.map_batch[0]]
return dict(zip(axis_indices, self.all_names))
@property
def rhs_in_axes(self):
axis_indices = [*self.map_cont[1], *self.map_batch[1]]
return dict(zip(axis_indices, self.all_names))
def all_pdot_specs(lhs_shape, rhs_shape):
for matching in axis_matchings(lhs_shape, rhs_shape):
for lists in partitions(matching, 4):
yield PdotTestSpec(*map(unzip2, lists))
def axis_matchings(lhs_shape, rhs_shape):
def helper(start, exc1, exc2):
yield ()
for i in range(start, len(lhs_shape)):
d1 = lhs_shape[i]
if i not in exc1:
for j, d2 in enumerate(rhs_shape):
if d1 == d2 and j not in exc2:
for matches in helper(i + 1, exc1 | {i}, exc2 | {j}):
yield ((i, j), *matches)
return helper(0, set(), set())
def gen_axis_names():
names = 'ijkl'
for n in it.count(1):
for chars in product(names, repeat=n):
yield ''.join(chars)
def schedules_from_pdot_spec(
spec: PdotTestSpec, lhs_shape: Tuple[int], rhs_shape: Tuple[int]
) -> Generator[Tuple[AxisResources, MeshSpec], None, None]:
logical_sizes = {
name: shape[ax]
for shape, in_axes in [(lhs_shape, spec.lhs_in_axes),
(rhs_shape, spec.rhs_in_axes)]
for ax, name in in_axes.items()}
yield from schedules(logical_sizes)
class PDotTests(XMapTestCase):
@ignore_xmap_warning()
@with_mesh([('r1', 2)])
def testPdotBasic(self):
def f(x, y):
return lax.pdot(x, y, 'i')
f_mapped = xmap(f,
in_axes=[{1: 'i'}, {0: 'i'}],
out_axes={},
axis_resources={'i': 'r1'})
rng = np.random.RandomState(0)
x = rng.randn(3, 8)
y = rng.randn(8, 5)
z = f_mapped(x, y)
self.assertAllClose(z, jnp.dot(x, y))
@ignore_xmap_warning()
@with_mesh([('r1', 2)])
def testPdotBatching(self):
def f(x, y):
return lax.pdot(x, y, 'i')
rng = np.random.RandomState(0)
x = rng.randn(2, 3, 8)
y = rng.randn(2, 8, 5)
f_mapped = xmap(f,
in_axes=[{0: 'j', 2: 'i'}, {0: 'j', 1: 'i'}],
out_axes=['j', ...],
axis_resources={'i': 'r1'})
z = f_mapped(x, y)
self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))
@ignore_xmap_warning()
@with_mesh([('r1', 2)])
def testPdotBatchingShardUncontractedDim(self):
def f(x, y):
return lax.pdot(x, y, 'i')
rng = np.random.RandomState(0)
x = rng.randn(2, 3, 8)
y = rng.randn(2, 8, 5)
f_mapped = xmap(f,
in_axes=[{0: 'j', 2: 'i'}, {0: 'j', 1: 'i'}],
out_axes=['j', ...],
axis_resources={'j': 'r1'})
z = f_mapped(x, y)
self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{next(test_counter)}",
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "pdot_spec": pdot_spec,
"axis_resources": axis_resources, "mesh_data": mesh_data}
for test_counter in [it.count()]
for lhs_shape, rhs_shape in product(
[(2,), (2, 4, 2, 1)],
repeat=2)
for pdot_spec in all_pdot_specs(lhs_shape, rhs_shape)
for axis_resources, mesh_data in schedules_from_pdot_spec(
pdot_spec, lhs_shape, rhs_shape)))
@ignore_xmap_warning()
def testPdotSystematic(self, lhs_shape, rhs_shape, pdot_spec, axis_resources,
mesh_data):
rng = jtu.rand_default(self.rng())
lhs = rng(lhs_shape, np.float32)
rhs = rng(rhs_shape, np.float32)
def pdot_fun(x, y):
# print(f'pdot(x:{x.aval.str_short()}, y:{y.aval.str_short()},\n'
# f' axis_name={contract_names},\n'
# f' pos_contract={spec.pos_contract_after_mapping}\n'
# f' pos_batch={spec.pos_batch_after_mapping})')
return jax.lax.pdot(x, y, axis_name=pdot_spec.contract_names,
pos_batch=pdot_spec.pos_batch_after_mapping,
pos_contract=pdot_spec.pos_contract_after_mapping)
fun = xmap(pdot_fun, in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes],
out_axes=[*pdot_spec.batch_names, ...],
axis_resources=axis_resources)
with with_mesh(mesh_data):
result = fun(lhs, rhs)
expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(result, expected, check_dtypes=False,
atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{next(test_counter)}",
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "pdot_spec": pdot_spec,
"axis_resources": axis_resources, "mesh_data": mesh_data}
for test_counter in [it.count()]
for lhs_shape, rhs_shape in product(
[(2,), (2, 4, 2, 1)],
repeat=2)
for pdot_spec in all_pdot_specs(lhs_shape, rhs_shape)
for axis_resources, mesh_data in schedules_from_pdot_spec(
pdot_spec, lhs_shape, rhs_shape)))
@ignore_xmap_warning()
def testPdotVJPSystematic(self, lhs_shape, rhs_shape, pdot_spec,
axis_resources, mesh_data):
rng = jtu.rand_default(self.rng())
lhs = rng(lhs_shape, np.float32)
rhs = rng(rhs_shape, np.float32)
expected_out, ref_vjp = jax.vjp(
lambda x, y: lax.dot_general(x, y, pdot_spec.dot_general_dim_nums),
lhs, rhs)
out_bar = rng(expected_out.shape, np.float32)
expected_lhs, expected_rhs = ref_vjp(out_bar)
def pdot_fun(x, y, out_bar):
pdot = partial(jax.lax.pdot,
axis_name=pdot_spec.contract_names,
pos_batch=pdot_spec.pos_batch_after_mapping,
pos_contract=pdot_spec.pos_contract_after_mapping)
_, pdot_vjp = jax.vjp(pdot, x, y)
return pdot_vjp(out_bar)
fun = xmap(pdot_fun,
in_axes=[pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes,
[*pdot_spec.batch_names, ...]],
out_axes=(pdot_spec.lhs_in_axes, pdot_spec.rhs_in_axes),
axis_resources=axis_resources)
with with_mesh(mesh_data):
lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(lhs_bar, expected_lhs, check_dtypes=False,
atol=tol, rtol=tol)
self.assertAllClose(rhs_bar, expected_rhs, check_dtypes=False,
atol=tol, rtol=tol)
@ignore_xmap_warning()
def test_xeinsum_vector_dot(self):
rng = np.random.RandomState(0)
x = rng.randn(3)
y = rng.randn(3)
out = xmap(partial(jnp.einsum, '{i},{i}->'),
in_axes=(['i'], ['i']), out_axes=[])(x, y)
expected = np.einsum('i,i->', x, y)
self.assertAllClose(out, expected, check_dtypes=False)
@ignore_xmap_warning()
def test_xeinsum_outer_product(self):
rng = np.random.RandomState(0)
x = rng.randn(3)
y = rng.randn(3)
out = xmap(partial(jnp.einsum, '{i},{j}->{i,j}'),
in_axes=(['i'], ['j']), out_axes=['i', 'j'])(x, y)
expected = np.einsum('i,j->ij', x, y)
self.assertAllClose(out, expected, check_dtypes=True)
@ignore_xmap_warning()
def test_xeinsum_matmul(self):
rng = np.random.RandomState(0)
x = rng.randn(3, 4)
y = rng.randn(4, 5)
def check(spec):
out = xmap(partial(jnp.einsum, spec),
in_axes=(['i', 'j'], ['j', 'k']),
out_axes=['i', 'k'])(x, y)
expected = np.einsum('ij,jk->ik', x, y)
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
self.assertAllClose(out, expected, check_dtypes=True,
atol=tol, rtol=tol)
check('{i,j},{j,k}->{i,k}')
check('{i,j},{k,j}->{k,i}') # order of named axes in the spec doesn't matter!
check('{j},{k,j}->{k}')
check('{i,j},{j}->{i}')
check('{j},{j}->{}')
def test_xeinsum_no_named_axes_vector_dot(self):
rng = np.random.RandomState(0)
x = rng.randn(3)
y = rng.randn(3)
out = jnp.einsum('i,i->', x, y, _use_xeinsum=True)
expected = np.einsum('i,i->', x, y)
self.assertAllClose(out, expected, check_dtypes=False)
def test_xeinsum_no_named_axes_batch_vector_dot(self):
rng = np.random.RandomState(0)
x = rng.randn(3, 2)
y = rng.randn(3, 2)
out = jnp.einsum('ij,ij->i', x, y, _use_xeinsum=True)
expected = np.einsum('ij,ij->i', x, y)
self.assertAllClose(out, expected, check_dtypes=True)
def test_xeinsum_no_named_axes_reduce_sum(self):
rng = np.random.RandomState(0)
x = rng.randn(3)
y = rng.randn()
out = jnp.einsum('i,->', x, y, _use_xeinsum=True)
expected = np.einsum('i,->', x, y)
self.assertAllClose(out, expected, check_dtypes=True)
class XMapErrorTest(jtu.JaxTestCase):
def setUp(self):
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
super().setUp()
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testRepeatedAxisResource(self):
def f(v):
return v * 4
with self.assertRaisesRegex(ValueError, r"distinct resources.*specified \('x', 'x'\) for axis a"):
fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': ('x', 'x')})
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testNestedDifferentResources(self):
@partial(xmap, in_axes={0: 'a'}, out_axes={0: 'a'}, axis_resources={'a': 'x'})
def f(x):
with mesh(np.empty((), dtype=np.object_), ()):
@partial(xmap, in_axes={0: 'b'}, out_axes={0: 'b'})
def h(x):
return x
return h(x)
xshape = (2, 5, 6)
x = jnp.arange(np.prod(xshape)).reshape(xshape)
with self.assertRaisesRegex(RuntimeError,
"Changing the resource environment.*"):
f(x)
@ignore_xmap_warning()
def testEmptyArgumentTrees(self):
with self.assertRaisesRegex(ValueError, "Failed to infer size of axes: i."):
xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...])({})
@ignore_xmap_warning()
@with_mesh([('x', 2), ('y', 2)])
def testAxesNotDivisibleByResources(self):
with self.assertRaisesRegex(ValueError, r"Size of axis i \(5\) is not divisible.*"
r"\(\('x', 'y'\), 4 in total\)"):
xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...],
axis_sizes={'i': 5}, axis_resources={'i': ('x', 'y')})({})
@ignore_xmap_warning()
def testInconsistentAxisSizes(self):
x5 = jnp.arange(5)
x6 = jnp.arange(6)
error = (r"The size of axis i was previously inferred to be 5, but found an "
r"argument of shape \(6,\) with in_axes specification \['i', ...\]. "
r"Shape mismatch occurs in dimension 0: 6 != 5")
with self.assertRaisesRegex(ValueError, error):
xmap(lambda x, y: x, in_axes=(['i', ...], ['i', ...]), out_axes=['i', ...])(x5, x6)
with self.assertRaisesRegex(ValueError, error):
xmap(lambda x: x, in_axes=['i', ...], out_axes=['i', ...], axis_sizes={'i': 5})(x6)
@ignore_xmap_warning()
def testInAxesRankError(self):
error = (r"One of xmap arguments has an in_axes specification of \['i', 'j', ...\], "
r"which implies that it has at least 2 dimensions, but the argument has rank 1")
with self.assertRaisesRegex(ValueError, error):
xmap(lambda x: x, in_axes=['i', 'j', ...], out_axes=['j', 'i', ...])(jnp.ones((5,)))
@ignore_xmap_warning()
def testOutAxesRankError(self):
error = (r"One of xmap outputs has an out_axes specification of {1: 'i'}, "
r"which requires the result of the xmapped function to have at least "
r"1 positional dimensions, but it only has 0")
with self.assertRaisesRegex(ValueError, error):
xmap(lambda x: x, in_axes=['i', ...], out_axes={1: 'i'})(jnp.ones((5,)))
@ignore_xmap_warning()
def testNegativeAxes(self):
with self.assertRaisesRegex(ValueError, "xmap doesn't support negative axes in in_axes"):
xmap(lambda x: x, in_axes={-1: 'i'}, out_axes={0: 'i'})(jnp.ones((5,)))
with self.assertRaisesRegex(ValueError, "xmap doesn't support negative axes in out_axes"):
xmap(lambda x: x, in_axes={0: 'i'}, out_axes={-1: 'i'})(jnp.ones((5,)))
@ignore_xmap_warning()
def testListAxesRankAssertion(self):
error = (r"xmap argument has an in_axes specification of \['i', None\], which "
r"asserts that it should be of rank 2, but the argument has rank 1 "
r"\(and shape \(5,\)\)")
with self.assertRaisesRegex(ValueError, error):
xmap(lambda x: x, in_axes=['i', None], out_axes=['i', None])(jnp.ones((5,)))
error = (r"xmap output has an out_axes specification of \['i', None\], which "
r"asserts that it should be of rank 2, but the output has rank 3 "
r"\(and shape \(5, 2, 2\)\)")
with self.assertRaisesRegex(ValueError, error):
xmap(lambda x: x.reshape((2, 2)),
in_axes=['i', None], out_axes=['i', None])(jnp.ones((5, 4)))
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())