Merge pull request #10408 from jakevdp:xmap-flake8

PiperOrigin-RevId: 443511304
This commit is contained in:
jax authors 2022-04-21 16:32:23 -07:00
commit 60223fb5f1

View File

@ -12,18 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
from contextlib import contextmanager
import functools
import itertools as it
import os
import re
import unittest
from itertools import product, permutations
from typing import (Tuple, List, NamedTuple, Dict, Generator, Sequence, Set,
Any, Hashable, Iterable, Iterator, Union, Optional)
from unittest import SkipTest, skip, skipIf
from typing import (Tuple, List, Dict, Generator, Iterator, Union, Optional)
from unittest import SkipTest
import numpy as np
from absl.testing import absltest
@ -37,17 +32,16 @@ from jax._src import test_util as jtu
from jax import vmap
from jax import lax
from jax import core
from jax.core import NamedShape, JaxprTypeError
from jax.core import NamedShape
from jax.experimental import maps
from jax.experimental import global_device_array
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.experimental.pjit import PartitionSpec as P
from jax.experimental.maps import Mesh, xmap, serial_loop, SerialLoop
from jax.experimental.maps import xmap, serial_loop, SerialLoop
from jax.errors import JAXTypeError
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.util import curry, unzip2, split_list, prod, safe_zip
from jax._src.lax.lax import DotDimensionNumbers
from jax._src.util import curry, unzip2, prod, safe_zip
from jax._src.lax.parallel import pgather
from jax.interpreters import batching, pxla
from jax.ad_checkpoint import checkpoint
@ -702,7 +696,6 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest):
axis_resources={'a': 'x', 'b': 'y'})
xshape = (8, 2, 4, 5)
x = jnp.arange(np.prod(xshape)).reshape(xshape)
y = f(x)
hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
match = re.search(r"sharding={devices=\[([0-9,]+)\][0-9,]+}", hlo)
self.assertIsNot(match, None)
@ -774,7 +767,6 @@ class NamedNumPyTest(XMapTestCase):
for mapped_axis in range(3)))
def testReductions(self, reduction, axes, mapped_axis):
axes_t = axes if isinstance(axes, tuple) else (axes,)
reduces_i = 'i' in axes_t
ref_red = partial(reduction,
axis=tuple(mapped_axis if a == 'i' else a + (a >= mapped_axis)
for a in axes_t))
@ -870,7 +862,7 @@ class NamedNNTest(XMapTestCase):
base_scaling = partial(jax.nn.initializers.variance_scaling, 100, fan, distr)
ref_sampler = lambda: base_scaling(in_axis=0, out_axis=1)(key, shape)
if map_in and map_out:
out_axes=['i', 'o', ...]
out_axes = ['i', 'o', ...]
named_shape = NamedShape(shape[2], i=shape[0], o=shape[1])
xmap_sampler = lambda: base_scaling(in_axis='i', out_axis='o')(key, named_shape)
elif map_in:
@ -1019,7 +1011,7 @@ class XMapGDATest(XMapTestCase):
ValueError,
('Got an input GDA to xmap with different partitioning than '
'specified in xmap. The partitioning must match.')):
out1 = f(gda_obj)
f(gda_obj)
class NewPrimitiveTest(XMapTestCase):
@ -1356,8 +1348,8 @@ class XMapErrorTest(jtu.JaxTestCase):
def f(v):
return v * 4
with self.assertRaisesRegex(ValueError, r"distinct resources.*specified \('x', 'x'\) for axis a"):
fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': ('x', 'x')})
xmap(f, in_axes=['a', ...], out_axes=['a', ...],
axis_resources={'a': ('x', 'x')})
@jtu.with_mesh([('y', 2)])
def testUndefinedAxisResource(self):