mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #7296 from apaszke:xmap-flag
PiperOrigin-RevId: 384993575
This commit is contained in:
commit
e951ac8521
@ -32,7 +32,7 @@ from .._src.tree_util import _replace_nones
|
||||
from ..api_util import (flatten_fun_nokwargs, flatten_axes, _ensure_index_tuple,
|
||||
donation_vector)
|
||||
from .._src import source_info_util
|
||||
from ..config import config
|
||||
from .._src.config import config
|
||||
from ..errors import JAXTypeError
|
||||
from ..interpreters import partial_eval as pe
|
||||
from ..interpreters import pxla
|
||||
@ -52,8 +52,6 @@ zip = safe_zip
|
||||
|
||||
xops = xc.ops
|
||||
|
||||
EXPERIMENTAL_SPMD_LOWERING = False
|
||||
|
||||
class FrozenDict(abc.Mapping):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.contents = dict(*args, **kwargs)
|
||||
@ -695,7 +693,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
mesh_in_axes,
|
||||
mesh_out_axes,
|
||||
donated_invars,
|
||||
EXPERIMENTAL_SPMD_LOWERING,
|
||||
config.experimental_xmap_spmd_lowering,
|
||||
*in_avals,
|
||||
tile_by_mesh_axes=True,
|
||||
do_resource_typecheck=None)
|
||||
@ -1072,7 +1070,7 @@ core.initial_to_final_param_rules[xmap_p] = _xmap_initial_to_final_params
|
||||
# -------- nested xmap handling --------
|
||||
|
||||
def _xmap_translation_rule(*args, **kwargs):
|
||||
if EXPERIMENTAL_SPMD_LOWERING:
|
||||
if config.experimental_xmap_spmd_lowering:
|
||||
return _xmap_translation_rule_spmd(*args, **kwargs)
|
||||
else:
|
||||
return _xmap_translation_rule_replica(*args, **kwargs)
|
||||
@ -1492,3 +1490,23 @@ def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
|
||||
return xmap(fun, in_axes=in_axes, out_axes={0: axis_name},
|
||||
axis_resources={axis_name: 'devices'})(*args, **kwargs)
|
||||
return f_pmapped
|
||||
|
||||
# -------- config flags --------
|
||||
|
||||
def _thread_local_flag_unsupported(_):
|
||||
raise RuntimeError("thread-local xmap flags not supported!")
|
||||
def _clear_compilation_cache(_):
|
||||
make_xmap_callable.cache_clear() # type: ignore
|
||||
|
||||
try:
|
||||
config.define_bool_state(
|
||||
name="experimental_xmap_spmd_lowering",
|
||||
default=False,
|
||||
help=("When set, multi-device xmaps computations will be compiled through "
|
||||
"the XLA SPMD partitioner instead of explicit cross-replica collectives. "
|
||||
"Not supported on CPU!"),
|
||||
update_global_hook=_clear_compilation_cache,
|
||||
update_thread_local_hook=_thread_local_flag_unsupported)
|
||||
except Exception:
|
||||
raise ImportError("jax.experimental.maps has to be imported before JAX flags "
|
||||
"are parsed")
|
||||
|
@ -37,7 +37,6 @@ from ._src.util import partial, prod, unzip2
|
||||
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
|
||||
from .lib import xla_bridge
|
||||
from .interpreters import xla
|
||||
from .experimental import maps
|
||||
from .experimental.maps import mesh
|
||||
|
||||
|
||||
@ -1062,13 +1061,11 @@ def with_and_without_mesh(f):
|
||||
old_spmd_lowering_flag = False
|
||||
def set_spmd_lowering_flag(val: bool):
|
||||
global old_spmd_lowering_flag
|
||||
maps.make_xmap_callable.cache_clear()
|
||||
old_spmd_lowering_flag = maps.EXPERIMENTAL_SPMD_LOWERING
|
||||
maps.EXPERIMENTAL_SPMD_LOWERING = val
|
||||
old_spmd_lowering_flag = config.experimental_xmap_spmd_lowering
|
||||
config.update('experimental_xmap_spmd_lowering', val)
|
||||
|
||||
def restore_spmd_lowering_flag():
|
||||
maps.make_xmap_callable.cache_clear()
|
||||
maps.EXPERIMENTAL_SPMD_LOWERING = old_spmd_lowering_flag
|
||||
config.update('experimental_xmap_spmd_lowering', old_spmd_lowering_flag)
|
||||
|
||||
class _cached_property:
|
||||
null = object()
|
||||
|
@ -216,13 +216,10 @@ class SPMDTestMixin:
|
||||
if jtu.device_under_test() not in ['tpu', 'gpu']:
|
||||
raise SkipTest
|
||||
super().setUp()
|
||||
jax.experimental.maps.make_xmap_callable.cache_clear()
|
||||
self.old_lowering_flag = jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING
|
||||
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True
|
||||
jtu.set_spmd_lowering_flag(True)
|
||||
|
||||
def tearDown(self):
|
||||
jax.experimental.maps.make_xmap_callable.cache_clear()
|
||||
jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = self.old_lowering_flag
|
||||
jtu.restore_spmd_lowering_flag()
|
||||
|
||||
|
||||
class XMapTest(XMapTestCase):
|
||||
@ -372,7 +369,7 @@ class XMapTest(XMapTestCase):
|
||||
(pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
|
||||
self.assertEqual(y[0].sharding_spec.mesh_mapping,
|
||||
(pxla.Replicated(2), pxla.ShardedAxis(0)) + (pxla.Replicated(2),) * (len(mesh) - 2))
|
||||
if maps.EXPERIMENTAL_SPMD_LOWERING:
|
||||
if config.experimental_xmap_spmd_lowering:
|
||||
hlo = jax.xla_computation(f)(x).as_hlo_text()
|
||||
# Make sure that there are non-partial sharding specs in the HLO
|
||||
self.assertRegex(hlo, r"sharding={devices=\[[0-9,]+\][0-9,]+}")
|
||||
@ -745,7 +742,7 @@ class NewPrimitiveTest(XMapTestCase):
|
||||
|
||||
@jtu.with_and_without_mesh
|
||||
def testGather(self, mesh, axis_resources):
|
||||
if axis_resources and not jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING:
|
||||
if axis_resources and not config.experimental_xmap_spmd_lowering:
|
||||
raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
|
||||
x = jnp.arange(12, dtype=np.float32).reshape((4, 3))
|
||||
y = jnp.arange(35).reshape((5, 7)) % 3
|
||||
|
Loading…
x
Reference in New Issue
Block a user