Merge pull request #7296 from apaszke:xmap-flag

PiperOrigin-RevId: 384993575
This commit is contained in:
jax authors 2021-07-15 13:12:38 -07:00
commit e951ac8521
3 changed files with 30 additions and 18 deletions

View File

@ -32,7 +32,7 @@ from .._src.tree_util import _replace_nones
from ..api_util import (flatten_fun_nokwargs, flatten_axes, _ensure_index_tuple,
donation_vector)
from .._src import source_info_util
from ..config import config
from .._src.config import config
from ..errors import JAXTypeError
from ..interpreters import partial_eval as pe
from ..interpreters import pxla
@ -52,8 +52,6 @@ zip = safe_zip
xops = xc.ops
EXPERIMENTAL_SPMD_LOWERING = False
class FrozenDict(abc.Mapping):
def __init__(self, *args, **kwargs):
self.contents = dict(*args, **kwargs)
@ -695,7 +693,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
mesh_in_axes,
mesh_out_axes,
donated_invars,
EXPERIMENTAL_SPMD_LOWERING,
config.experimental_xmap_spmd_lowering,
*in_avals,
tile_by_mesh_axes=True,
do_resource_typecheck=None)
@ -1072,7 +1070,7 @@ core.initial_to_final_param_rules[xmap_p] = _xmap_initial_to_final_params
# -------- nested xmap handling --------
def _xmap_translation_rule(*args, **kwargs):
if EXPERIMENTAL_SPMD_LOWERING:
if config.experimental_xmap_spmd_lowering:
return _xmap_translation_rule_spmd(*args, **kwargs)
else:
return _xmap_translation_rule_replica(*args, **kwargs)
@ -1492,3 +1490,23 @@ def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
return xmap(fun, in_axes=in_axes, out_axes={0: axis_name},
axis_resources={axis_name: 'devices'})(*args, **kwargs)
return f_pmapped
# -------- config flags --------
def _thread_local_flag_unsupported(_):
raise RuntimeError("thread-local xmap flags not supported!")
def _clear_compilation_cache(_):
make_xmap_callable.cache_clear() # type: ignore
try:
config.define_bool_state(
name="experimental_xmap_spmd_lowering",
default=False,
help=("When set, multi-device xmaps computations will be compiled through "
"the XLA SPMD partitioner instead of explicit cross-replica collectives. "
"Not supported on CPU!"),
update_global_hook=_clear_compilation_cache,
update_thread_local_hook=_thread_local_flag_unsupported)
except Exception:
raise ImportError("jax.experimental.maps has to be imported before JAX flags "
"are parsed")

View File

@ -37,7 +37,6 @@ from ._src.util import partial, prod, unzip2
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from .lib import xla_bridge
from .interpreters import xla
from .experimental import maps
from .experimental.maps import mesh
@ -1062,13 +1061,11 @@ def with_and_without_mesh(f):
old_spmd_lowering_flag = False
def set_spmd_lowering_flag(val: bool):
global old_spmd_lowering_flag
maps.make_xmap_callable.cache_clear()
old_spmd_lowering_flag = maps.EXPERIMENTAL_SPMD_LOWERING
maps.EXPERIMENTAL_SPMD_LOWERING = val
old_spmd_lowering_flag = config.experimental_xmap_spmd_lowering
config.update('experimental_xmap_spmd_lowering', val)
def restore_spmd_lowering_flag():
maps.make_xmap_callable.cache_clear()
maps.EXPERIMENTAL_SPMD_LOWERING = old_spmd_lowering_flag
config.update('experimental_xmap_spmd_lowering', old_spmd_lowering_flag)
class _cached_property:
null = object()

View File

@ -216,13 +216,10 @@ class SPMDTestMixin:
if jtu.device_under_test() not in ['tpu', 'gpu']:
raise SkipTest
super().setUp()
jax.experimental.maps.make_xmap_callable.cache_clear()
self.old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
jtu.set_spmd_lowering_flag(True)
def tearDown(self):
jax.experimental.maps.make_xmap_callable.cache_clear()
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = self.old_lowering_flag
jtu.restore_spmd_lowering_flag()
class XMapTest(XMapTestCase):
@ -372,7 +369,7 @@ class XMapTest(XMapTestCase):
(pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
self.assertEqual(y[0].sharding_spec.mesh_mapping,
(pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2))
if maps.EXPERIMENTAL_SPMD_LOWERING:
if config.experimental_xmap_spmd_lowering:
hlo = jax.xla_computation(f)(x).as_hlo_text()
# Make sure that there are non-partial sharding specs in the HLO
self.assertRegex(hlo, r"sharding={devices=\[[0-9,]+\][0-9,]+}")
@ -745,7 +742,7 @@ class NewPrimitiveTest(XMapTestCase):
@jtu.with_and_without_mesh
def testGather(self, mesh, axis_resources):
if axis_resources and not jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING:
if axis_resources and not config.experimental_xmap_spmd_lowering:
raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
x = jnp.arange(12, dtype=np.float32).reshape((4, 3))
y = jnp.arange(35).reshape((5, 7)) % 3