mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #5527 from google:remove-soft-pmap
PiperOrigin-RevId: 354328027
This commit is contained in:
commit
543dcb37e3
@ -42,7 +42,7 @@ pytype_library(
|
||||
"*_test.py",
|
||||
"**/*_test.py",
|
||||
],
|
||||
),
|
||||
) + ["experimental/maps.py"], # until xmap is moved out of experimental
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//third_party/py/jax/jaxlib:_pocketfft",
|
||||
|
@ -69,7 +69,6 @@ from .api import (
|
||||
shapecheck,
|
||||
ShapedArray,
|
||||
ShapeDtypeStruct,
|
||||
soft_pmap,
|
||||
# TODO(phawkins): hide tree* functions from jax, update callers to use
|
||||
# jax.tree_util.
|
||||
treedef_is_leaf,
|
||||
@ -86,6 +85,7 @@ from .api import (
|
||||
xla, # TODO(phawkins): update users to avoid this.
|
||||
xla_computation,
|
||||
)
|
||||
from .experimental.maps import soft_pmap
|
||||
from .version import __version__
|
||||
|
||||
# These submodules are separate because they are in an import cycle with
|
||||
|
@ -484,15 +484,6 @@ class XeinsumSpecParser:
|
||||
|
||||
### parallel primitives
|
||||
|
||||
def _allreduce_soft_pmap_rule(prim, reducer, vals, mapped, chunk_size,
|
||||
*, axis_name, axis_index_groups):
|
||||
if axis_index_groups is not None:
|
||||
raise NotImplementedError("soft_pmap does not yet support axis_index_groups")
|
||||
reduced_vals = [reducer(x, [0]) if m else x for x, m in zip(vals, mapped)]
|
||||
outs = prim.bind(*reduced_vals, axis_name=axis_name,
|
||||
axis_index_groups=axis_index_groups)
|
||||
return outs, (False,) * len(vals)
|
||||
|
||||
# This is only used for collectives that do not include the vmapped axis name,
|
||||
# which is why the rule is so simple.
|
||||
def _collective_batcher(prim, args, dims, **params):
|
||||
@ -593,8 +584,6 @@ def _psum_transpose_rule(cts, *args, axis_name, axis_index_groups):
|
||||
psum_p = core.Primitive('psum')
|
||||
psum_p.multiple_results = True
|
||||
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
|
||||
pxla.soft_pmap_rules[psum_p] = \
|
||||
partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
|
||||
xla.parallel_translations[psum_p] = partial(_allreduce_translation_rule, lax.add_p) # type: ignore
|
||||
ad.deflinear2(psum_p, _psum_transpose_rule)
|
||||
pxla.multi_host_supported_collectives.add(psum_p)
|
||||
@ -957,14 +946,8 @@ def _axis_index_translation_rule(c, *, axis_name, axis_env, platform):
|
||||
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
|
||||
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
|
||||
|
||||
def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
|
||||
assert not vals and not mapped
|
||||
idx = axis_index(axis_name) # type: ignore
|
||||
return idx * chunk_size + np.arange(chunk_size, dtype=np.int32), True
|
||||
|
||||
axis_index_p = core.Primitive('axis_index')
|
||||
xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
|
||||
pxla.soft_pmap_rules[axis_index_p] = _axis_index_soft_pmap_rule # type: ignore
|
||||
axis_index_p.def_abstract_eval(
|
||||
lambda *args, **params: ShapedArray((), np.int32))
|
||||
pxla.multi_host_supported_collectives.add(axis_index_p)
|
||||
@ -991,6 +974,7 @@ def _axis_index_bind(*, axis_name):
|
||||
axis_index_p.def_custom_bind(_axis_index_bind)
|
||||
|
||||
def _process_axis_index(self, frame):
|
||||
assert frame.size is not None
|
||||
return batching.BatchTracer(self, lax_numpy.arange(frame.size, dtype=np.int32), 0)
|
||||
batching.BatchTrace.process_axis_index = _process_axis_index # type: ignore
|
||||
|
||||
|
31
jax/api.py
31
jax/api.py
@ -1577,37 +1577,6 @@ def pmap(
|
||||
return f_pmapped
|
||||
|
||||
|
||||
def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
|
||||
) -> Callable:
|
||||
if not config.omnistaging_enabled:
|
||||
raise NotImplementedError("soft_pmap requires omnistaging.")
|
||||
warn("soft_pmap is an experimental feature and probably has bugs!")
|
||||
_check_callable(fun)
|
||||
axis_name = core._TempAxisName(fun) if axis_name is None else axis_name
|
||||
|
||||
if any(axis != 0 for axis in tree_leaves(in_axes)):
|
||||
raise ValueError(f"soft_pmap in_axes leaves must be 0 or None, got {in_axes}")
|
||||
|
||||
@wraps(fun)
|
||||
@api_boundary
|
||||
def f_pmapped(*args, **kwargs):
|
||||
f = lu.wrap_init(fun)
|
||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||
in_axes_flat = flatten_axes("soft_pmap in_axes", in_tree, (in_axes, 0))
|
||||
axis_size = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "soft_pmap")
|
||||
for arg in args_flat: _check_arg(arg)
|
||||
flat_fun, out_tree = flatten_fun(f, in_tree)
|
||||
# See note about out_axes_thunk in pmap for the explanation of why we choose this key
|
||||
out_axes_thunk = HashableFunction(
|
||||
lambda: tuple(flatten_axes("soft_pmap out_axes", out_tree(), 0)),
|
||||
closure=())
|
||||
outs = pxla.soft_pmap(flat_fun, *args_flat, axis_name=axis_name,
|
||||
axis_size=axis_size, in_axes=tuple(in_axes_flat),
|
||||
out_axes_thunk=out_axes_thunk)
|
||||
return tree_unflatten(out_tree(), outs)
|
||||
return f_pmapped
|
||||
|
||||
|
||||
def mask(fun: Callable, in_shapes, out_shape=None) -> Callable:
|
||||
_check_callable(fun)
|
||||
unique_ids = masking.UniqueIds()
|
||||
|
@ -26,7 +26,8 @@ from .. import numpy as jnp
|
||||
from .. import core
|
||||
from .. import linear_util as lu
|
||||
from ..api import _check_callable, _check_arg
|
||||
from ..tree_util import tree_flatten, tree_unflatten, all_leaves
|
||||
from ..tree_util import (tree_flatten, tree_unflatten, all_leaves,
|
||||
_replace_nones, tree_map, tree_leaves)
|
||||
from ..api_util import flatten_fun_nokwargs, flatten_axes
|
||||
from ..interpreters import partial_eval as pe
|
||||
from ..interpreters import pxla
|
||||
@ -839,3 +840,27 @@ def subst_eqn_axis_names(eqn, axis_subst: Dict[AxisName, Tuple[AxisName]]):
|
||||
axis_names = (axis_names,)
|
||||
new_axis_names = sum((axis_subst.get(name, (name,)) for name in axis_names), ())
|
||||
return eqn._replace(params=dict(eqn.params, axis_name=new_axis_names))
|
||||
|
||||
|
||||
# -------- soft_pmap --------
|
||||
|
||||
def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
|
||||
) -> Callable:
|
||||
warn("soft_pmap is an experimental feature and probably has bugs!")
|
||||
_check_callable(fun)
|
||||
axis_name = core._TempAxisName(fun) if axis_name is None else axis_name
|
||||
|
||||
if any(axis != 0 for axis in tree_leaves(in_axes)):
|
||||
raise ValueError(f"soft_pmap in_axes leaves must be 0 or None, got {in_axes}")
|
||||
proxy = object()
|
||||
in_axes = _replace_nones(proxy, in_axes)
|
||||
in_axes = tree_map(lambda i: {i: axis_name} if i is not proxy else {}, in_axes)
|
||||
|
||||
|
||||
@wraps(fun)
|
||||
def f_pmapped(*args, **kwargs):
|
||||
mesh_devices = np.array(xb.local_devices())
|
||||
with mesh(mesh_devices, ['devices']):
|
||||
return xmap(fun, in_axes=in_axes, out_axes={0: axis_name},
|
||||
axis_resources={axis_name: 'devices'})(*args, **kwargs)
|
||||
return f_pmapped
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union, Sequence
|
||||
|
||||
import jax
|
||||
from ..config import config
|
||||
@ -28,16 +28,24 @@ from . import partial_eval as pe
|
||||
|
||||
map = safe_map
|
||||
|
||||
def batch(fun: lu.WrappedFun, axis_name, axis_size, in_dims, out_dim_dests,
|
||||
BatchDim = Optional[int]
|
||||
BatchDims = Sequence[BatchDim]
|
||||
AxesSpec = Union[Callable[[], BatchDims], BatchDims]
|
||||
|
||||
def batch(fun: lu.WrappedFun, axis_name: core.AxisName,
|
||||
axis_size: Optional[int], in_dims: AxesSpec, out_dim_dests: AxesSpec,
|
||||
) -> lu.WrappedFun:
|
||||
# anlogue of `jvp` in ad.py
|
||||
# TODO(mattjj,apaszke): change type of axis_size to be int, not Optional[int]
|
||||
fun, out_dims_thunk = batch_subtrace(fun)
|
||||
return _match_axes(batchfun(fun, axis_name, axis_size, in_dims),
|
||||
axis_size, out_dims_thunk, out_dim_dests)
|
||||
axis_size, in_dims, out_dims_thunk, out_dim_dests)
|
||||
|
||||
@lu.transformation
|
||||
def batchfun(axis_name, axis_size, in_dims, *in_vals):
|
||||
# analogue of `jvpfun` in ad.py
|
||||
if axis_size is None:
|
||||
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
|
||||
in_dims = in_dims() if callable(in_dims) else in_dims
|
||||
in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
|
||||
and not isinstance(core.get_aval(x), core.AbstractUnit) # non-omnistaging
|
||||
@ -60,7 +68,9 @@ def batch_subtrace(main, in_dims, *in_vals):
|
||||
yield out_vals, out_dims
|
||||
|
||||
@lu.transformation
|
||||
def _match_axes(axis_size, out_dims_thunk, out_dim_dests, *in_vals):
|
||||
def _match_axes(axis_size, in_dims, out_dims_thunk, out_dim_dests, *in_vals):
|
||||
if axis_size is None:
|
||||
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
|
||||
out_vals = yield in_vals, {}
|
||||
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
|
||||
out_dims = out_dims_thunk()
|
||||
|
@ -45,14 +45,13 @@ from .. import core
|
||||
from .. import linear_util as lu
|
||||
from .. import lazy
|
||||
from ..abstract_arrays import array_types
|
||||
from ..core import ConcreteArray, ShapedArray, Var, Literal
|
||||
from ..core import ConcreteArray, ShapedArray
|
||||
from .._src.util import (partial, unzip2, unzip3, prod, safe_map, safe_zip,
|
||||
extend_name_stack, wrap_name, assert_unreachable,
|
||||
tuple_insert, tuple_delete, taggedtuple, curry)
|
||||
from ..lib import xla_bridge as xb
|
||||
from ..lib import xla_client as xc
|
||||
from ..tree_util import tree_flatten, tree_map
|
||||
from .batching import broadcast, not_mapped, moveaxis
|
||||
from . import batching
|
||||
from . import partial_eval as pe
|
||||
from . import xla
|
||||
@ -1563,7 +1562,6 @@ def vtile(f_flat,
|
||||
|
||||
_forbidden_primitives = {
|
||||
'xla_pmap': 'pmap',
|
||||
'soft_pmap': 'soft_pmap',
|
||||
'sharded_call': 'sharded_jit',
|
||||
}
|
||||
def _sanitize_mesh_jaxpr(jaxpr):
|
||||
@ -1597,166 +1595,6 @@ def mesh_sharding_specs(axis_sizes, axis_names):
|
||||
return ShardingSpec(sharding, mesh_mapping)
|
||||
return mk_sharding_spec
|
||||
|
||||
|
||||
# ------------------- soft_pmap -------------------
|
||||
|
||||
def soft_pmap_impl(fun: lu.WrappedFun, *args, axis_name, axis_size, in_axes, out_axes_thunk):
|
||||
abstract_args = unsafe_map(xla.abstractify, args)
|
||||
compiled_fun = _soft_pmap_callable(fun, axis_name, axis_size, in_axes, out_axes_thunk,
|
||||
*abstract_args)
|
||||
return compiled_fun(*args)
|
||||
|
||||
@lu.cache
|
||||
def _soft_pmap_callable(fun, axis_name, axis_size, in_axes, out_axes_thunk, *avals):
|
||||
mapped_avals = [core.mapped_aval(axis_size, in_axis, aval) if in_axis is not None else aval
|
||||
for in_axis, aval in safe_zip(in_axes, avals)]
|
||||
with core.extend_axis_env(axis_name, axis_size, None): # type: ignore
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_avals)
|
||||
out_axes = out_axes_thunk()
|
||||
assert all(out_axis == 0 for out_axis in out_axes)
|
||||
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
|
||||
|
||||
num_devices = xb.local_device_count()
|
||||
chunk_size, ragged = divmod(axis_size, num_devices)
|
||||
if ragged:
|
||||
msg = f"number of devices {num_devices} must divide axis size {axis_size}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
jaxpr, _, consts = _soft_pmap_jaxpr(jaxpr, consts, in_axes,
|
||||
axis_name, axis_size, chunk_size)
|
||||
jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
|
||||
if jaxpr_replicas != 1: raise NotImplementedError
|
||||
|
||||
tuple_args = len(avals) > 100 # pass long arg lists as tuple for TPU
|
||||
|
||||
c = xb.make_computation_builder("soft_pmap_{}".format(fun.__name__))
|
||||
xla_consts = map(partial(xb.constant, c), consts)
|
||||
chunked_avals = [core.unmapped_aval(chunk_size, in_axis, aval) if in_axis is not None else aval
|
||||
for in_axis, aval in safe_zip(in_axes, mapped_avals)]
|
||||
xla_args, _ = xla._xla_callable_args(c, chunked_avals, tuple_args)
|
||||
axis_env = xla.AxisEnv(num_devices, (axis_name,), (num_devices,))
|
||||
out_nodes = xla.jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts,
|
||||
'soft_pmap', *xla_args)
|
||||
built = c.Build(xops.Tuple(c, out_nodes))
|
||||
|
||||
compile_options = xb.get_compile_options(
|
||||
num_replicas=num_devices, num_partitions=1, device_assignment=None)
|
||||
compile_options.tuple_arguments = tuple_args
|
||||
backend = xb.get_backend(None)
|
||||
compiled = xla.backend_compile(backend, built, compile_options)
|
||||
|
||||
input_specs = [
|
||||
ShardingSpec(
|
||||
sharding=tuple_insert((_UNSHARDED_INSTANCE,) *
|
||||
(aval.ndim - 1), in_axis, Chunked(num_devices)),
|
||||
mesh_mapping=[ShardedAxis(0)])
|
||||
if in_axis is not None else ShardingSpec(
|
||||
sharding=[_UNSHARDED_INSTANCE] * aval.ndim,
|
||||
mesh_mapping=[Replicated(num_devices)])
|
||||
for aval, in_axis in safe_zip(avals, in_axes)
|
||||
]
|
||||
input_indices = [spec and spec_to_indices(aval.shape, spec)
|
||||
for aval, spec in safe_zip(avals, input_specs)]
|
||||
handle_args = partial(shard_args, compiled.local_devices(), input_indices)
|
||||
handle_outs = soft_pmap_avals_to_results_handler(num_devices, chunk_size, out_avals)
|
||||
|
||||
return partial(execute_replicated, compiled, backend, handle_args, handle_outs)
|
||||
|
||||
def _soft_pmap_jaxpr(jaxpr, consts, in_axes, axis_name, axis_size, chunk_size):
|
||||
assert all(in_axis is None or in_axis == 0 for in_axis in in_axes), in_axes
|
||||
mapped_invars = [in_axis is not None for in_axis in in_axes]
|
||||
fun = partial(_soft_pmap_interp, chunk_size, jaxpr, consts, mapped_invars)
|
||||
in_avals = [core.unmapped_aval(chunk_size, in_axis, v.aval) if in_axis is not None else v.aval
|
||||
for v, in_axis in safe_zip(jaxpr.invars, in_axes)]
|
||||
with core.extend_axis_env(axis_name, axis_size, None):
|
||||
return pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
|
||||
|
||||
def _soft_pmap_interp(chunk_size, jaxpr, consts, mapped_invars, *args):
|
||||
|
||||
env: Dict[Var, Tuple[Any, bool]] = {}
|
||||
|
||||
def read(atom: Union[Var, Literal]) -> Tuple[Any, bool]:
|
||||
if isinstance(atom, Literal):
|
||||
return (atom.val, False)
|
||||
else:
|
||||
return env[atom]
|
||||
|
||||
def write(v: Var, val: Any, mapped: bool) -> None:
|
||||
env[v] = (val, mapped)
|
||||
|
||||
write(core.unitvar, core.unit, False)
|
||||
map(write, jaxpr.constvars, consts, (False,) * len(consts))
|
||||
map(write, jaxpr.invars, args, mapped_invars)
|
||||
for eqn in jaxpr.eqns:
|
||||
in_vals, in_mapped = unzip2(map(read, eqn.invars))
|
||||
if eqn.primitive in xla.parallel_translations:
|
||||
rule = soft_pmap_rules[eqn.primitive]
|
||||
out_vals, out_mapped = rule(in_vals, in_mapped, chunk_size, **eqn.params)
|
||||
if not eqn.primitive.multiple_results:
|
||||
out_vals, out_mapped = [out_vals], [out_mapped]
|
||||
elif isinstance(eqn.primitive, core.CallPrimitive):
|
||||
# we just inline here for convenience
|
||||
call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
|
||||
out_vals = _soft_pmap_interp(chunk_size, call_jaxpr, (), in_mapped, *in_vals)
|
||||
out_mapped = [True] * len(out_vals)
|
||||
elif isinstance(eqn.primitive, core.MapPrimitive):
|
||||
raise NotImplementedError # TODO
|
||||
else:
|
||||
if any(in_mapped):
|
||||
rule = batching.get_primitive_batcher(eqn.primitive, None)
|
||||
in_axes = [0 if m else batching.not_mapped for m in in_mapped]
|
||||
out_vals, out_axes = rule(in_vals, in_axes, **eqn.params)
|
||||
if not eqn.primitive.multiple_results:
|
||||
out_vals, out_axes = [out_vals], [out_axes]
|
||||
out_vals = [moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x
|
||||
for x, d in safe_zip(out_vals, out_axes)]
|
||||
out_mapped = [d is not not_mapped for d in out_axes]
|
||||
else:
|
||||
out_vals = eqn.primitive.bind(*in_vals, **eqn.params)
|
||||
if not eqn.primitive.multiple_results:
|
||||
out_vals = [out_vals]
|
||||
out_mapped = [False for _ in out_vals]
|
||||
map(write, eqn.outvars, out_vals, out_mapped)
|
||||
|
||||
out_vals, out_mapped = unzip2(map(read, jaxpr.outvars))
|
||||
out_vals = [out if mapped else broadcast(out, chunk_size, 0)
|
||||
for out, mapped in safe_zip(out_vals, out_mapped)]
|
||||
return out_vals
|
||||
|
||||
# TODO(mattjj): dedup w/ with other aval_to_result_handler via ShardingSpec
|
||||
def soft_pmap_avals_to_results_handler(num_devices, chunk_size, out_avals):
|
||||
nouts = len(out_avals)
|
||||
handlers = [soft_pmap_aval_to_result_handler(chunk_size, num_devices, aval)
|
||||
for aval in out_avals]
|
||||
def handler(out_bufs):
|
||||
buffers = [[result_to_populate] * num_devices for _ in range(nouts)]
|
||||
for r, tuple_buf in enumerate(out_bufs):
|
||||
for i, buf in enumerate(tuple_buf):
|
||||
buffers[i][r] = buf
|
||||
assert not any(buf is result_to_populate for bufs in buffers
|
||||
for buf in bufs)
|
||||
return [h(bufs) for h, bufs in safe_zip(handlers, buffers)]
|
||||
return handler
|
||||
|
||||
def soft_pmap_aval_to_result_handler(chunk_size, num_devices, aval):
|
||||
axis_size = chunk_size * num_devices
|
||||
if aval is core.abstract_unit:
|
||||
return lambda _: core.unit
|
||||
elif isinstance(aval, core.ShapedArray):
|
||||
new_aval = aval.update(shape=(axis_size,) + aval.shape)
|
||||
spec = ShardingSpec(
|
||||
sharding=(Chunked(num_devices),) + (_UNSHARDED_INSTANCE,) * aval.ndim,
|
||||
mesh_mapping=(ShardedAxis(0),))
|
||||
return lambda bufs: ShardedDeviceArray(new_aval, spec, bufs)
|
||||
else:
|
||||
raise TypeError(aval)
|
||||
|
||||
soft_pmap_p = core.MapPrimitive('soft_pmap')
|
||||
soft_pmap = soft_pmap_p.bind
|
||||
soft_pmap_p.def_impl(soft_pmap_impl)
|
||||
|
||||
soft_pmap_rules: Dict[core.Primitive, Callable] = {}
|
||||
|
||||
@contextmanager
|
||||
def maybe_extend_axis_env(*args, **kwargs):
|
||||
with core.extend_axis_env(*args, **kwargs):
|
||||
|
@ -35,8 +35,8 @@ from jax import tree_util
|
||||
from jax import lax
|
||||
from jax import random
|
||||
from jax.core import ShapedArray
|
||||
from jax.api import (pmap, soft_pmap, jit, vmap, jvp, grad, make_jaxpr,
|
||||
linearize, device_put)
|
||||
from jax import (pmap, soft_pmap, jit, vmap, jvp, grad, make_jaxpr,
|
||||
linearize, device_put)
|
||||
from jax.lib import xla_bridge
|
||||
from jax._src.util import prod, safe_map
|
||||
from jax.interpreters import pxla
|
||||
@ -100,9 +100,6 @@ def tearDownModule():
|
||||
os.environ["XLA_FLAGS"] = prev_xla_flags
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
|
||||
ignore_soft_pmap_warning = partial(
|
||||
jtu.ignore_warning, message="soft_pmap is an experimental.*")
|
||||
|
||||
ignore_jit_of_pmap_warning = partial(
|
||||
jtu.ignore_warning, message=".*jit-of-pmap.*")
|
||||
|
||||
@ -1239,7 +1236,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(r, arr + 1)
|
||||
self.assertEqual(len(r.device_buffers), 6)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapBatchMatmul(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
n = 4 * xla_bridge.device_count()
|
||||
@ -1249,7 +1246,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = np.einsum('nij,njk->nik', xs, ys)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapBatchMatmulJit(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
n = 4 * xla_bridge.device_count()
|
||||
@ -1259,7 +1256,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = np.einsum('nij,njk->nik', xs, ys)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapPsumConstant(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
n = 4 * xla_bridge.device_count()
|
||||
@ -1269,7 +1266,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = n * np.ones(n)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapPsum(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
n = 4 * xla_bridge.device_count()
|
||||
@ -1279,7 +1276,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = np.ones(n) / n
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapAxisIndex(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
n = 4 * xla_bridge.device_count()
|
||||
@ -1289,7 +1286,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = 2 * np.arange(n)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapOfJit(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
n = 4 * xla_bridge.device_count()
|
||||
@ -1299,7 +1296,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = 3 * np.arange(n)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapNested(self):
|
||||
raise SkipTest("not implemented") # TODO(mattjj): re-implement
|
||||
n = 4 * xla_bridge.device_count()
|
||||
@ -1314,7 +1311,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = np.arange(n ** 2).reshape(n, n).T
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
@ignore_xmap_warning()
|
||||
def testGradOfSoftPmap(self):
|
||||
raise SkipTest("not implemented") # TODO(mattjj): re-implement
|
||||
n = 4 * xla_bridge.device_count()
|
||||
@ -1327,7 +1324,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
expected = np.repeat(np.arange(n)[:, None], n, axis=1)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_soft_pmap_warning()
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapDevicePersistence(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
device_count = xla_bridge.device_count()
|
||||
|
Loading…
x
Reference in New Issue
Block a user