mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #10408 from jakevdp:xmap-flake8
PiperOrigin-RevId: 443511304
This commit is contained in:
commit
60223fb5f1
@ -12,18 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa
|
||||
|
||||
from contextlib import contextmanager
|
||||
import functools
|
||||
import itertools as it
|
||||
import os
|
||||
import re
|
||||
import unittest
|
||||
from itertools import product, permutations
|
||||
from typing import (Tuple, List, NamedTuple, Dict, Generator, Sequence, Set,
|
||||
Any, Hashable, Iterable, Iterator, Union, Optional)
|
||||
from unittest import SkipTest, skip, skipIf
|
||||
from typing import (Tuple, List, Dict, Generator, Iterator, Union, Optional)
|
||||
from unittest import SkipTest
|
||||
|
||||
import numpy as np
|
||||
from absl.testing import absltest
|
||||
@ -37,17 +32,16 @@ from jax._src import test_util as jtu
|
||||
from jax import vmap
|
||||
from jax import lax
|
||||
from jax import core
|
||||
from jax.core import NamedShape, JaxprTypeError
|
||||
from jax.core import NamedShape
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import global_device_array
|
||||
from jax.experimental.pjit import pjit, with_sharding_constraint
|
||||
from jax.experimental.pjit import PartitionSpec as P
|
||||
from jax.experimental.maps import Mesh, xmap, serial_loop, SerialLoop
|
||||
from jax.experimental.maps import xmap, serial_loop, SerialLoop
|
||||
from jax.errors import JAXTypeError
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.util import curry, unzip2, split_list, prod, safe_zip
|
||||
from jax._src.lax.lax import DotDimensionNumbers
|
||||
from jax._src.util import curry, unzip2, prod, safe_zip
|
||||
from jax._src.lax.parallel import pgather
|
||||
from jax.interpreters import batching, pxla
|
||||
from jax.ad_checkpoint import checkpoint
|
||||
@ -702,7 +696,6 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest):
|
||||
axis_resources={'a': 'x', 'b': 'y'})
|
||||
xshape = (8, 2, 4, 5)
|
||||
x = jnp.arange(np.prod(xshape)).reshape(xshape)
|
||||
y = f(x)
|
||||
hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
|
||||
match = re.search(r"sharding={devices=\[([0-9,]+)\][0-9,]+}", hlo)
|
||||
self.assertIsNot(match, None)
|
||||
@ -774,7 +767,6 @@ class NamedNumPyTest(XMapTestCase):
|
||||
for mapped_axis in range(3)))
|
||||
def testReductions(self, reduction, axes, mapped_axis):
|
||||
axes_t = axes if isinstance(axes, tuple) else (axes,)
|
||||
reduces_i = 'i' in axes_t
|
||||
ref_red = partial(reduction,
|
||||
axis=tuple(mapped_axis if a == 'i' else a + (a >= mapped_axis)
|
||||
for a in axes_t))
|
||||
@ -870,7 +862,7 @@ class NamedNNTest(XMapTestCase):
|
||||
base_scaling = partial(jax.nn.initializers.variance_scaling, 100, fan, distr)
|
||||
ref_sampler = lambda: base_scaling(in_axis=0, out_axis=1)(key, shape)
|
||||
if map_in and map_out:
|
||||
out_axes=['i', 'o', ...]
|
||||
out_axes = ['i', 'o', ...]
|
||||
named_shape = NamedShape(shape[2], i=shape[0], o=shape[1])
|
||||
xmap_sampler = lambda: base_scaling(in_axis='i', out_axis='o')(key, named_shape)
|
||||
elif map_in:
|
||||
@ -1019,7 +1011,7 @@ class XMapGDATest(XMapTestCase):
|
||||
ValueError,
|
||||
('Got an input GDA to xmap with different partitioning than '
|
||||
'specified in xmap. The partitioning must match.')):
|
||||
out1 = f(gda_obj)
|
||||
f(gda_obj)
|
||||
|
||||
|
||||
class NewPrimitiveTest(XMapTestCase):
|
||||
@ -1356,8 +1348,8 @@ class XMapErrorTest(jtu.JaxTestCase):
|
||||
def f(v):
|
||||
return v * 4
|
||||
with self.assertRaisesRegex(ValueError, r"distinct resources.*specified \('x', 'x'\) for axis a"):
|
||||
fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
|
||||
axis_resources={'a': ('x', 'x')})
|
||||
xmap(f, in_axes=['a', ...], out_axes=['a', ...],
|
||||
axis_resources={'a': ('x', 'x')})
|
||||
|
||||
@jtu.with_mesh([('y', 2)])
|
||||
def testUndefinedAxisResource(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user