mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[sharding_in_types] Default axis_types to Auto
for all axis_names if user does not set any AxisType. Also resolve some TODOs now that we have a way for user to set the mesh.
PiperOrigin-RevId: 704944255
This commit is contained in:
parent
b5e4fd161d
commit
41f490aef4
@ -1626,10 +1626,8 @@ def get_sharding(sharding, ndim):
|
||||
return _maybe_modify_sharding(sharding)
|
||||
|
||||
context_mesh = mesh_lib.get_abstract_mesh()
|
||||
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
|
||||
# code.
|
||||
if not context_mesh:
|
||||
return None
|
||||
return RuntimeError("Please set the mesh via `jax.set_mesh` API.")
|
||||
assert sharding is None
|
||||
return NamedSharding(context_mesh, P(*[None] * ndim))
|
||||
|
||||
@ -1692,7 +1690,7 @@ class ShapedArray(UnshapedArray):
|
||||
self.dtype.name)
|
||||
dt_str = dt_str.replace('void', 'float0')
|
||||
if hasattr(self, 'sharding') and self.sharding is not None:
|
||||
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
|
||||
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec) # type: ignore
|
||||
return f'{dt_str}[{shapestr}]'
|
||||
else:
|
||||
shapestr = ','.join(map(str, self.shape))
|
||||
@ -2658,16 +2656,10 @@ def _check_call(ctx_factory, prim, in_atoms, params):
|
||||
return aval
|
||||
for v, x in zip(call_jaxpr.invars, in_atoms):
|
||||
if not typecompat(substitute(v.aval), x.aval):
|
||||
# TODO(yashkatariya): Remove this once numpy array's aval has a sharding
|
||||
# on it.
|
||||
if (config.sharding_in_types.value and isinstance(x, Literal) and
|
||||
v.aval.sharding is not None and x.val.ndim == 0):
|
||||
pass
|
||||
else:
|
||||
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
|
||||
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
|
||||
f"{x.aval} to jaxpr expecting type "
|
||||
f"{substitute(v.aval)}")
|
||||
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
|
||||
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
|
||||
f"{x.aval} to jaxpr expecting type "
|
||||
f"{substitute(v.aval)}")
|
||||
env[v] = x if type(x) is Var else x.val
|
||||
|
||||
_check_jaxpr(ctx_factory, call_jaxpr)
|
||||
|
@ -111,8 +111,6 @@ class AxisTypes(enum.Enum):
|
||||
return self.name
|
||||
|
||||
def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:
|
||||
if axis_types is None:
|
||||
return {}
|
||||
d = {}
|
||||
for t, names in axis_types.items():
|
||||
if isinstance(names, tuple):
|
||||
@ -179,7 +177,7 @@ class Mesh(contextlib.ContextDecorator):
|
||||
|
||||
devices: np.ndarray
|
||||
axis_names: tuple[MeshAxisName, ...]
|
||||
axis_types: MeshAxisType | None
|
||||
axis_types: MeshAxisType
|
||||
|
||||
def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
|
||||
axis_names: str | Sequence[MeshAxisName], *,
|
||||
@ -199,9 +197,9 @@ class Mesh(contextlib.ContextDecorator):
|
||||
f"devices.ndim == {devices.ndim} and "
|
||||
f"len(axis_names) == {len(axis_names)}.")
|
||||
|
||||
# TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
|
||||
axis_types_tuple = (None if axis_types is None else
|
||||
tuple(axis_types.items()))
|
||||
axis_types = ({AxisTypes.Auto: axis_names} if axis_types is None else
|
||||
axis_types)
|
||||
axis_types_tuple = tuple(axis_types.items())
|
||||
key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple)
|
||||
val = _mesh_object_dict.get(key, None)
|
||||
if val is not None:
|
||||
@ -337,7 +335,7 @@ class Mesh(contextlib.ContextDecorator):
|
||||
def _repr(self):
|
||||
if self.empty:
|
||||
return "Mesh(device_ids=[], axis_names=())"
|
||||
atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
|
||||
atr = f", axis_types={self.axis_types}"
|
||||
return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r}{atr})"
|
||||
|
||||
def __repr__(self):
|
||||
@ -378,14 +376,13 @@ class AbstractMesh:
|
||||
def __init__(self, shape_tuple: tuple[tuple[str, int], ...], *,
|
||||
axis_types: MeshAxisType | None = None):
|
||||
self.shape_tuple = shape_tuple
|
||||
self.axis_types = axis_types
|
||||
if self.shape_tuple:
|
||||
self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
|
||||
else:
|
||||
self._axis_names, self._axis_sizes = (), ()
|
||||
# TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
|
||||
self._axis_types_tuple = (None if axis_types is None else
|
||||
tuple(axis_types.items()))
|
||||
self.axis_types = ({AxisTypes.Auto: self._axis_names} if axis_types is None
|
||||
else axis_types)
|
||||
self._axis_types_tuple = tuple(self.axis_types.items())
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.shape_tuple, self._axis_types_tuple))
|
||||
@ -399,7 +396,7 @@ class AbstractMesh:
|
||||
self._axis_types_tuple == other._axis_types_tuple)
|
||||
|
||||
def __repr__(self):
|
||||
atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
|
||||
atr = f", axis_types={self.axis_types}"
|
||||
return f"AbstractMesh({self.shape_tuple}{atr})"
|
||||
|
||||
@property
|
||||
@ -432,26 +429,18 @@ class AbstractMesh:
|
||||
|
||||
@functools.cached_property
|
||||
def _are_all_axes_collective(self) -> bool:
|
||||
if self.axis_types is None:
|
||||
return False
|
||||
return all(t == AxisTypes.Collective for t in self.axis_types.keys())
|
||||
|
||||
@functools.cached_property
|
||||
def _are_all_axes_auto(self) -> bool:
|
||||
if self.axis_types is None:
|
||||
return False
|
||||
return all(t == AxisTypes.Auto for t in self.axis_types.keys())
|
||||
|
||||
@functools.cached_property
|
||||
def _any_axis_collective(self) -> bool:
|
||||
if self.axis_types is None:
|
||||
return False
|
||||
return any(t == AxisTypes.Collective for t in self.axis_types.keys())
|
||||
|
||||
@functools.cached_property
|
||||
def _any_axis_auto(self) -> bool:
|
||||
if self.axis_types is None:
|
||||
return False
|
||||
return any(t == AxisTypes.Auto for t in self.axis_types.keys())
|
||||
|
||||
@property
|
||||
@ -494,8 +483,6 @@ def _raise_value_error(name):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_abstract_mesh(mesh: AbstractMesh):
|
||||
if mesh is not None and mesh.axis_types is None:
|
||||
raise RuntimeError('Please set the AxisTypes of Mesh.')
|
||||
prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh)
|
||||
try:
|
||||
yield
|
||||
|
@ -698,9 +698,6 @@ def get_abstract_mesh_from_avals(in_avals):
|
||||
return None
|
||||
m = None
|
||||
for a in in_avals:
|
||||
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
|
||||
if a.sharding is None: # type: ignore
|
||||
continue
|
||||
if m is not None and m != a.sharding.mesh:
|
||||
raise ValueError(
|
||||
f'Mesh for all inputs should be equal. Got one mesh: {m} and'
|
||||
@ -1788,9 +1785,7 @@ def _pjit_lower(
|
||||
lowering_parameters: mlir.LoweringParameters,
|
||||
pgle_profiler: profiler.PGLEProfiler | None):
|
||||
if config.sharding_in_types.value:
|
||||
cur_mesh = mesh_lib.get_concrete_mesh()
|
||||
mesh = cur_mesh if isinstance(cur_mesh, mesh_lib.Mesh) else None
|
||||
api_name = 'jit'
|
||||
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
|
||||
else:
|
||||
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
|
||||
if resource_env is not None else (None, 'jit'))
|
||||
|
@ -5483,6 +5483,24 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertIn('@Sharding', f.lower(arr).as_text())
|
||||
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'), {mesh_lib.AxisTypes.Auto: ('x', 'y')})
|
||||
def test_only_auto(self, mesh):
|
||||
np_inp = np.arange(16.).reshape(8, 2)
|
||||
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
|
||||
|
||||
@jax.jit
|
||||
def f(x, x2):
|
||||
y = x * 2
|
||||
self.assertEqual(y.sharding.spec, P(P.UNCONSTRAINED, None))
|
||||
z = jnp.sin(y)
|
||||
self.assertEqual(z.sharding.spec, P(P.UNCONSTRAINED, None))
|
||||
a = z @ x2
|
||||
self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED))
|
||||
return a
|
||||
|
||||
out = f(arr, arr.T)
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
|
||||
|
||||
def test_auto_user(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'),
|
||||
axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
|
||||
|
Loading…
x
Reference in New Issue
Block a user