mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[sharding_in_types] Upstream changes from the experiment of defaulting the sharding_in_types config to True. There aren't a lot of failures in TGP, but we can at least upstream these changes until we work on the remaining failures.
PiperOrigin-RevId: 720639755
This commit is contained in:
parent 7a4a53ad9e
commit 8f248fe626
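For context: with sharding_in_types enabled, abstract values carry a sharding, and many of the changes below make that path tolerate empty meshes and concrete (device-backed) shardings. Below is a minimal illustrative sketch of the kind of sharded program this affects; it is not part of the commit, it uses only public jax.sharding APIs, and the 1x1 mesh is made up so the example runs on a single device.

# Illustrative sketch only -- not part of this diff.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))
x = jax.device_put(jnp.arange(16.).reshape(8, 2),
                   NamedSharding(mesh, P('x', 'y')))

@jax.jit
def f(a):
  return a * 2

out = f(x)
print(out.sharding)  # inspect the sharding jax chose for the jit output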
@@ -817,8 +817,8 @@ already_configured_with_absl = False
 trace_state = config_ext.Config(None, include_in_jit_key=True)
 axis_env_state = config_ext.Config((), include_in_jit_key=True)
 mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
-abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
-device_context = config_ext.Config((), include_in_jit_key=True)
+abstract_mesh_context_manager = config_ext.Config(None, include_in_jit_key=True)
+device_context = config_ext.Config(None, include_in_jit_key=True)
 compute_on_context_manager = config_ext.Config(None, include_in_jit_key=True)
 xla_metadata_context_manager = config_ext.Config(None, include_in_jit_key=True)
 
@@ -591,8 +591,10 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[
 
 def check_avals_context_mesh(avals, prim_name):
   if config.sharding_in_types.value:
-    for a in avals:
-      cur_mesh = mesh_lib.get_abstract_mesh()
+    cur_mesh = mesh_lib.get_abstract_mesh()
+    for a in avals:
+      if a.sharding.mesh.empty or cur_mesh.empty:
+        continue
       if a.sharding.mesh != cur_mesh:
         raise ValueError(
             f"For primitive {prim_name}, context mesh {cur_mesh} should match"
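Note: the check above now skips avals whose mesh is empty (or when there is no context mesh). A standalone sketch of the check's shape, using simplified stand-in types rather than jax's internals:

# Simplified stand-ins; jax's real mesh/aval types live in jax._src.
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyMesh:
  axis_names: tuple = ()
  @property
  def empty(self):
    return not self.axis_names

@dataclass
class ToyAval:
  mesh: ToyMesh

def check_context_mesh(avals, cur_mesh, prim_name):
  for a in avals:
    # An empty mesh means "no sharding information yet", so it is never an error.
    if a.mesh.empty or cur_mesh.empty:
      continue
    if a.mesh != cur_mesh:
      raise ValueError(
          f"For primitive {prim_name}, context mesh {cur_mesh} should match "
          f"the operand mesh {a.mesh}")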
@@ -1778,11 +1780,11 @@ def get_sharding(sharding, ndim):
           "Length of sharding.spec must be equal to aval's ndim. Got"
           f" sharding.spec {out_s.spec} and aval.ndim {ndim}")
   else:
-    context_mesh = mesh_lib.get_abstract_mesh()
-    if not context_mesh:
+    cur_mesh = mesh_lib.get_abstract_mesh()
+    if cur_mesh.empty:
       raise RuntimeError("Please set the mesh via `jax.set_mesh` API.")
     assert sharding is None
-    out_s = NamedSharding(context_mesh, P(*[None] * ndim))
+    out_s = NamedSharding(cur_mesh, P(*[None] * ndim))
   if not isinstance(out_s.mesh, mesh_lib.AbstractMesh):
     raise ValueError("Mesh of an aval must be an AbstractMesh. "
                      f"Got {out_s.mesh} of type {type(out_s.mesh)}")
@@ -1924,6 +1926,11 @@ class DShapedArray(UnshapedArray):
     weak_type = self.weak_type
     return DShapedArray(shape, dtype, weak_type)
 
+  @property
+  def sharding(self):
+    from jax._src.sharding_impls import NamedSharding  # type: ignore
+    return NamedSharding(mesh_lib.AbstractMesh(()), P())
+
   def _len(self, tracer):
     return self.shape[0]
 
@ -1663,8 +1663,10 @@ def lower_jaxpr_to_fun(
|
||||
flat_args = [
|
||||
replicate_trailing_dims(entry_lowering_ctx, o, a)
|
||||
if (a is not core.abstract_token and
|
||||
dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # pytype: disable=attribute-error
|
||||
for o, s, a in zip(flat_args, ir_arg_shardings, input_avals)
|
||||
dtypes.issubdtype(a.dtype, dtypes.extended) and
|
||||
(s is None or all_unconstrained(rs, a))) else o # pytype: disable=attribute-error
|
||||
for o, s, a, rs in zip(flat_args, ir_arg_shardings, input_avals,
|
||||
arg_shardings) # type: ignore
|
||||
]
|
||||
|
||||
_, token_args, unflattened_args = util.split_list(
|
||||
@ -1717,8 +1719,10 @@ def lower_jaxpr_to_fun(
|
||||
flat_outputs = [
|
||||
replicate_trailing_dims(entry_lowering_ctx, o, a)
|
||||
if (a is not core.abstract_token and
|
||||
dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # pytype: disable=attribute-error
|
||||
for o, s, a in zip(flat_outputs, ir_result_shardings, output_avals)
|
||||
dtypes.issubdtype(a.dtype, dtypes.extended) and
|
||||
(s is None or all_unconstrained(rs, a))) else o # pytype: disable=attribute-error
|
||||
for o, s, a, rs in zip(flat_outputs, ir_result_shardings, output_avals,
|
||||
result_shardings) # type: ignore
|
||||
]
|
||||
|
||||
func_dialect.return_(flat_outputs)
|
||||
@ -1917,7 +1921,8 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
source_info = eqn.source_info.replace(
|
||||
name_stack=name_stack + eqn.source_info.name_stack)
|
||||
loc = _source_info_to_location(ctx, eqn.primitive, source_info)
|
||||
with source_info_util.user_context(eqn.source_info.traceback), loc:
|
||||
with (source_info_util.user_context(eqn.source_info.traceback), loc,
|
||||
eqn.ctx.manager):
|
||||
override_rule = get_override_lowering_rule(eqn.primitive)
|
||||
platform_rules: dict[str, LoweringRule] = {}
|
||||
default_rule: LoweringRule | None = None
|
||||
|
@ -2102,7 +2102,8 @@ def _get_num_devices(
|
||||
for s in shardings:
|
||||
if isinstance(s, UnspecifiedValue):
|
||||
continue
|
||||
elif isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
|
||||
elif (isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh) and
|
||||
not s.mesh.empty):
|
||||
if abstract_mesh is not None and abstract_mesh != s.mesh:
|
||||
raise ValueError("AbstractMesh should be the same across all "
|
||||
f"shardings. Got {abstract_mesh} and {s.mesh}")
|
||||
@ -2158,6 +2159,9 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
|
||||
donated_invars, out_shardings, out_layouts)
|
||||
|
||||
def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
|
||||
if len(device_assignment) == 1:
|
||||
return shardings
|
||||
|
||||
np_dev = np.vectorize(lambda i: device_assignment[i],
|
||||
otypes=[object])(np.arange(len(device_assignment)))
|
||||
|
||||
@ -2170,6 +2174,9 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
|
||||
out = []
|
||||
for s, a in zip(shardings, avals):
|
||||
if isinstance(s, UnspecifiedValue) and a.sharding is not None:
|
||||
if a.sharding.mesh.empty:
|
||||
out.append(s)
|
||||
else:
|
||||
spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
|
||||
for sp in a.sharding.spec])
|
||||
if a.sharding.mesh._any_axis_auto else a.sharding.spec)
|
||||
@ -2792,7 +2799,10 @@ def _maybe_get_and_check_out_shardings(
|
||||
if (aval is not core.abstract_token and
|
||||
dtypes.issubdtype(aval.dtype, dtypes.extended)):
|
||||
xla_s = sharding_impls.logical_sharding(aval, xla_s)
|
||||
try:
|
||||
new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig)) # type: ignore
|
||||
except:
|
||||
new_out_shardings.append(xla_s)
|
||||
else:
|
||||
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)
|
||||
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # pytype: disable=attribute-error
|
||||
|
@ -229,10 +229,12 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
|
||||
msg.format(', '.join(str(x) for x in xs_flat
|
||||
if not hasattr(x, 'shape')))) from err
|
||||
|
||||
xs_avals = [core.get_aval(x) for x in xs_flat]
|
||||
|
||||
if (config.sharding_in_types.value and
|
||||
not all(x.aval.sharding.spec[0] is None for x in xs_flat)):
|
||||
not all(a.sharding.spec[0] is None for a in xs_avals)):
|
||||
raise ValueError('0th dimension of all xs should be replicated. Got '
|
||||
f'{", ".join(str(x.aval.sharding.spec) for x in xs_flat)}')
|
||||
f'{", ".join(str(a.sharding.spec) for a in xs_avals)}')
|
||||
|
||||
if length is not None:
|
||||
try:
|
||||
@ -270,7 +272,6 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
|
||||
stacked_y = tree_map(stack, *maybe_reversed(ys))
|
||||
return carry, stacked_y
|
||||
|
||||
xs_avals = [core.get_aval(x) for x in xs_flat]
|
||||
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
||||
dbg_body = api_util.tracing_debug_info("scan", f, (init, xs), {})
|
||||
|
||||
|
@ -64,6 +64,7 @@ from jax._src.lax.utils import (
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.mesh import get_abstract_mesh
|
||||
from jax._src.sharding_impls import (PmapSharding, NamedSharding,
|
||||
PartitionSpec as P, canonicalize_sharding)
|
||||
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
|
||||
@ -179,7 +180,7 @@ def _broadcast_shapes_uncached(*shapes):
|
||||
# Raise ValueError here for backward compatibility.
|
||||
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err
|
||||
|
||||
def broadcast_shardings(*avals) -> NamedSharding:
|
||||
def broadcast_shardings(*avals):
|
||||
fst, *rst = avals
|
||||
if not rst:
|
||||
return fst.sharding
|
||||
@ -585,11 +586,8 @@ def _convert_element_type(
|
||||
new_dtype = np.dtype(new_dtype)
|
||||
new_dtype = dtypes.dtype(new_dtype, canonicalize=True)
|
||||
|
||||
if (config.sharding_in_types.value and sharding is None and
|
||||
isinstance(operand, Array)):
|
||||
sharding = operand.aval.sharding
|
||||
|
||||
sharding = canonicalize_sharding(sharding, check_mesh_consistency=False) # type: ignore
|
||||
if sharding is not None and not isinstance(sharding, Sharding):
|
||||
raise ValueError(f'{sharding=} must be an instance of jax.sharding.Sharding')
|
||||
|
||||
if (warn_on_complex_to_real_cast and
|
||||
dtypes.issubdtype(old_dtype, np.complexfloating) and
|
||||
@ -1373,6 +1371,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
|
||||
raise NotImplementedError(
|
||||
'`out_sharding` argument of `dot_general` only supports NamedSharding '
|
||||
'instances. Please file a bug if this is not enough for your use case.')
|
||||
out_sharding = canonicalize_sharding(out_sharding)
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
cdims = (api_util._ensure_index_tuple(lhs_contract),
|
||||
api_util._ensure_index_tuple(rhs_contract))
|
||||
@ -1933,13 +1932,13 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
|
||||
dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
|
||||
fill_value = _convert_element_type(fill_value, dtype, weak_type)
|
||||
if (sharding is not None and not isinstance(sharding, PmapSharding) and
|
||||
isinstance(fill_value, array.ArrayImpl) and
|
||||
not config.sharding_in_types.value):
|
||||
isinstance(fill_value, array.ArrayImpl) and sharding.is_concrete):
|
||||
broadcast_shape = sharding.shard_shape(shape)
|
||||
shard = broadcast(fill_value, broadcast_shape)
|
||||
return array.make_array_from_callback(shape, sharding, lambda _: shard)
|
||||
|
||||
if config.sharding_in_types.value and sharding is not None:
|
||||
if (config.sharding_in_types.value and sharding is not None and
|
||||
not sharding.is_concrete):
|
||||
return broadcast(fill_value, shape, sharding=sharding)
|
||||
else:
|
||||
return broadcast(fill_value, shape)
|
||||
@ -2150,7 +2149,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
|
||||
return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr]
|
||||
|
||||
if (config.sharding_in_types.value and sharding is None and shape is None and
|
||||
isinstance(x, Array)):
|
||||
isinstance(x, core.Tracer)):
|
||||
sharding = x.aval.sharding
|
||||
else:
|
||||
# If `x` has a sharding but no `_committed` attribute
|
||||
@ -3184,6 +3183,13 @@ def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type,
|
||||
|
||||
def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type,
|
||||
sharding):
|
||||
if sharding is None:
|
||||
return operand.sharding
|
||||
if sharding.is_concrete:
|
||||
if isinstance(sharding, NamedSharding):
|
||||
return NamedSharding(sharding.mesh.abstract_mesh, sharding.spec)
|
||||
else:
|
||||
return None
|
||||
return sharding
|
||||
|
||||
def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type,
|
||||
@ -3268,7 +3274,7 @@ convert_element_type_p = Primitive('convert_element_type')
|
||||
def _convert_element_type_bind_with_trace(trace, args, params):
|
||||
sharding = params['sharding']
|
||||
operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
|
||||
if sharding is not None and not config.sharding_in_types.value:
|
||||
if sharding is not None and sharding.is_concrete:
|
||||
with core.set_current_trace(trace):
|
||||
operand = pjit.with_sharding_constraint(operand, sharding)
|
||||
return operand
|
||||
@ -3303,8 +3309,6 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
|
||||
aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
|
||||
out = mlir.convert_hlo(ctx, operand, aval_in, aval_out)
|
||||
if config.sharding_in_types.value:
|
||||
if sharding is not None:
|
||||
assert aval_out.sharding == sharding
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
@ -3734,6 +3738,10 @@ def _dot_batch_rule(
|
||||
else:
|
||||
rhs_shape = np.shape(rhs)
|
||||
if out_sharding is not None:
|
||||
cur_mesh = get_abstract_mesh()
|
||||
if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual:
|
||||
out_sharding = None
|
||||
else:
|
||||
raise NotImplementedError("vmap with out_sharding is not supported. "
|
||||
"Please open an issue.")
|
||||
batched_out = invoke_prim(
|
||||
@ -4433,8 +4441,13 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
|
||||
dyn_limits.append(bound)
|
||||
new_shape = (stacked_size,) + _merge_dyn_shape(shape, dyn_limits)
|
||||
if sharding is not None:
|
||||
raise NotImplementedError('Implement broadcast_in_dim_batch_rule')
|
||||
result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions)
|
||||
if sharding.mesh._are_all_axes_auto or sharding.mesh._are_all_axes_manual:
|
||||
sharding = None
|
||||
else:
|
||||
raise NotImplementedError('Implement sharding support for '
|
||||
'broadcast_in_dim_batch_rule')
|
||||
result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions,
|
||||
sharding=sharding)
|
||||
out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None]
|
||||
out_bdim = batching.make_batch_axis(
|
||||
result.ndim, 0, zip(out_ragged_axes, out_ragged_sizes))
|
||||
@ -5108,8 +5121,9 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding):
|
||||
return [reshape(t, operand.aval.shape)]
|
||||
else:
|
||||
if config.sharding_in_types.value:
|
||||
t_s = operand.sharding.with_spec(
|
||||
tuple(map(str, np.take(operand.aval.sharding.spec, dimensions))))
|
||||
t_s = operand.aval.sharding.with_spec(
|
||||
tuple(map(lambda s: s if s is None else str(s),
|
||||
np.take(operand.aval.sharding.spec, dimensions))))
|
||||
else:
|
||||
t_s = None
|
||||
return [transpose(reshape(t, np.take(operand.aval.shape, dimensions),
|
||||
@ -5119,13 +5133,18 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding):
|
||||
def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions,
|
||||
sharding):
|
||||
if sharding is not None:
|
||||
raise NotImplementedError
|
||||
if sharding.mesh._are_all_axes_manual or sharding.mesh._are_all_axes_auto:
|
||||
sharding = None
|
||||
else:
|
||||
raise NotImplementedError('reshape batch sharding support')
|
||||
operand, = batched_args
|
||||
bdim, = batch_dims
|
||||
operand = batching.moveaxis(operand, bdim, 0)
|
||||
if dimensions is not None:
|
||||
dimensions = (0,) + tuple(np.add(1, dimensions))
|
||||
return reshape(operand, operand.shape[:1] + new_sizes, dimensions), 0
|
||||
out = reshape(operand, operand.shape[:1] + new_sizes, dimensions,
|
||||
sharding=sharding)
|
||||
return out, 0
|
||||
|
||||
|
||||
def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding):
|
||||
@ -6689,16 +6708,18 @@ _zeros: Callable = partial(full_like, fill_value=0)
|
||||
|
||||
def _zero(x):
|
||||
if config.sharding_in_types.value:
|
||||
x_aval = core.get_aval(x)
|
||||
return full_like(x, shape=(), fill_value=0,
|
||||
sharding=x.aval.sharding.with_spec(P())) # type: ignore
|
||||
sharding=x_aval.sharding.with_spec(P())) # type: ignore
|
||||
return full_like(x, shape=(), fill_value=0)
|
||||
|
||||
_ones: Callable = partial(full_like, fill_value=1)
|
||||
|
||||
def _one(x):
|
||||
if config.sharding_in_types.value:
|
||||
x_aval = core.get_aval(x)
|
||||
return full_like(x, shape=(), fill_value=1,
|
||||
sharding=x.aval.sharding.with_spec(P()))
|
||||
sharding=x_aval.sharding.with_spec(P()))
|
||||
return full_like(x, shape=(), fill_value=1)
|
||||
|
||||
_twos: Callable = partial(full_like, fill_value=2)
|
||||
|
@ -49,11 +49,10 @@ def _get_array_abstraction_level(a): return a.array_abstraction_level
|
||||
|
||||
def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
|
||||
if config.sharding_in_types.value:
|
||||
if rule is None:
|
||||
cur_mesh = mesh_lib.get_abstract_mesh()
|
||||
if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual: # type: ignore
|
||||
if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual:
|
||||
return None if num_out is None else [None] * num_out
|
||||
else:
|
||||
if rule is None:
|
||||
raise ValueError(
|
||||
f'sharding rule for {prim.name} is not implemented. Please file a'
|
||||
' bug at https://github.com/jax-ml/jax/issues. You can work around'
|
||||
|
@@ -120,6 +120,16 @@ def axis_types_to_names(name_to_type: dict[str, AxisTypes]):
     d[t].append(n)
   return {t: ns[0] if len(ns) == 1 else tuple(ns) for t, ns in d.items()}
 
+def to_axis_types_tuple(axis_types):
+  out = []
+  for t, names in axis_types.items():
+    if isinstance(names, tuple):
+      new_names = names[0] if len(names) == 1 else names
+    else:
+      new_names = names
+    out.append((t, new_names))
+  return tuple(out)
+
 
 _mesh_object_dict = {}  # type: ignore
 
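The new to_axis_types_tuple helper normalizes the axis_types mapping so that a bare axis name and a one-element tuple compare (and hash) the same way, which is what the test_make_mesh_axis_types test further down relies on. A standalone sketch of the behaviour, with a stand-in enum instead of jax's AxisTypes:

# Stand-in enum; jax's real AxisTypes lives in jax._src.mesh.
import enum

class AxisTypes(enum.Enum):
  Auto = enum.auto()
  Explicit = enum.auto()
  Manual = enum.auto()

def to_axis_types_tuple(axis_types):
  out = []
  for t, names in axis_types.items():
    if isinstance(names, tuple):
      new_names = names[0] if len(names) == 1 else names
    else:
      new_names = names
    out.append((t, new_names))
  return tuple(out)

# 'x' and ('x',) normalize to the same hashable key:
assert (to_axis_types_tuple({AxisTypes.Auto: 'x'}) ==
        to_axis_types_tuple({AxisTypes.Auto: ('x',)}))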
@ -200,7 +210,7 @@ class Mesh(contextlib.ContextDecorator):
|
||||
|
||||
axis_types = ({AxisTypes.Auto: axis_names} if axis_types is None else
|
||||
axis_types)
|
||||
axis_types_tuple = tuple(axis_types.items())
|
||||
axis_types_tuple = to_axis_types_tuple(axis_types)
|
||||
if len(axis_names_to_types(axis_types).keys()) != len(axis_names):
|
||||
raise ValueError(
|
||||
"Number of axis names in axis_types should match the number of"
|
||||
@ -405,7 +415,7 @@ class AbstractMesh:
|
||||
self._axis_names, self._axis_sizes = (), ()
|
||||
self.axis_types = ({AxisTypes.Auto: self._axis_names}
|
||||
if axis_types is None else axis_types)
|
||||
self._axis_types_tuple = tuple(self.axis_types.items())
|
||||
self._axis_types_tuple = to_axis_types_tuple(self.axis_types)
|
||||
if len(self._name_to_type.keys()) != len(self._axis_names):
|
||||
raise ValueError(
|
||||
"Number of axis names in axis_types should match the number of"
|
||||
@ -424,7 +434,8 @@ class AbstractMesh:
|
||||
self._axis_types_tuple == other._axis_types_tuple)
|
||||
|
||||
def __repr__(self):
|
||||
mesh_repr = ", ".join(f"'{n}': {v}" for n, v in self.shape_tuple)
|
||||
mesh_repr = (", ".join(f"'{n}': {v}" for n, v in self.shape_tuple)
|
||||
if self.shape_tuple else "()")
|
||||
atr = f", axis_types={self.axis_types}"
|
||||
return f"AbstractMesh({mesh_repr}{atr})"
|
||||
|
||||
@@ -537,8 +548,11 @@ def set_abstract_mesh(mesh: AbstractMesh):
   finally:
     jax_config.abstract_mesh_context_manager.set_local(prev_val)
 
+empty_abstract_mesh = AbstractMesh(())
+
 def get_abstract_mesh():
-  return jax_config.abstract_mesh_context_manager.value
+  val = jax_config.abstract_mesh_context_manager.value
+  return empty_abstract_mesh if val is None else val
 
 
 @contextlib.contextmanager
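get_abstract_mesh now returns an empty AbstractMesh sentinel instead of None, so callers can uniformly ask .empty (as several hunks above do) instead of checking for None. A rough standalone sketch of the sentinel pattern; the names are illustrative, not jax's:

# Illustrative sentinel-instead-of-None pattern; not jax's implementation.
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyAbstractMesh:
  axis_names: tuple = ()
  @property
  def empty(self) -> bool:
    return not self.axis_names

EMPTY_MESH = ToyAbstractMesh()
_current_mesh = None  # would be set by a context manager elsewhere

def get_mesh() -> ToyAbstractMesh:
  # Callers can always write `if get_mesh().empty: ...` with no None checks.
  return EMPTY_MESH if _current_mesh is None else _current_mesh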
@ -5554,8 +5554,9 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
||||
# array([1, 2, 3])
|
||||
weak_type = dtype is None and dtypes.is_weakly_typed(object)
|
||||
if (config.sharding_in_types.value and device is None and
|
||||
isinstance(object, Array)):
|
||||
isinstance(object, core.Tracer)):
|
||||
sharding = object.aval.sharding
|
||||
sharding = None if sharding.mesh.empty else sharding
|
||||
else:
|
||||
sharding = canonicalize_device_to_sharding(device) # type: ignore
|
||||
|
||||
|
@ -82,11 +82,6 @@ def promote_dtypes(*args: ArrayLike) -> list[Array]:
|
||||
else:
|
||||
to_dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) # type: ignore[assignment]
|
||||
if config.sharding_in_types.value:
|
||||
return [lax._convert_element_type(x, to_dtype, weak_type,
|
||||
getattr(x, "sharding", None))
|
||||
for x in args]
|
||||
else:
|
||||
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
|
||||
|
||||
|
||||
|
@ -622,8 +622,8 @@ def _infer_params_impl(
|
||||
attr_token = _attr_token(flat_fun, in_type)
|
||||
|
||||
abstract_mesh = (
|
||||
get_abstract_mesh_from_avals(in_type)
|
||||
if not mesh_lib.get_abstract_mesh() else mesh_lib.get_abstract_mesh())
|
||||
get_abstract_mesh_from_avals(in_avals)
|
||||
if mesh_lib.get_abstract_mesh().empty else mesh_lib.get_abstract_mesh())
|
||||
with mesh_lib.set_abstract_mesh(abstract_mesh):
|
||||
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
|
||||
flat_fun, in_type, attr_token, dbg,
|
||||
@ -677,13 +677,16 @@ def get_abstract_mesh_from_avals(in_avals):
|
||||
return None
|
||||
m = None
|
||||
for a in in_avals:
|
||||
if a is core.abstract_token:
|
||||
continue
|
||||
if a.sharding.mesh.empty: # type: ignore
|
||||
continue
|
||||
if m is not None and m != a.sharding.mesh:
|
||||
raise ValueError(
|
||||
f'Mesh for all inputs should be equal. Got one mesh: {m} and'
|
||||
f' another mesh: {a.sharding.mesh}')
|
||||
m = a.sharding.mesh # type: ignore
|
||||
assert isinstance(m, AbstractMesh)
|
||||
return m
|
||||
return AbstractMesh(()) if m is None else m
|
||||
|
||||
|
||||
class InferParamsCacheEntry:
|
||||
@ -722,6 +725,10 @@ def _infer_params(
|
||||
else:
|
||||
resource_env = None
|
||||
pjit_mesh = None
|
||||
if resource_env is not None and mesh_lib.get_concrete_mesh() is not None:
|
||||
raise ValueError(
|
||||
'Using `with mesh:` context manager and `jax.sharding.use_mesh`'
|
||||
' together is not allowed.')
|
||||
|
||||
dbg = tracing_debug_info(
|
||||
'jit', fun, args, kwargs, static_argnums=ji.static_argnums,
|
||||
@ -1561,6 +1568,8 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
|
||||
arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True))
|
||||
if hasattr(arg, 'sharding') and arg.sharding is not None
|
||||
else (UNSPECIFIED, False))
|
||||
if isinstance(arg_s, NamedSharding) and arg_s.mesh.empty:
|
||||
arg_s, committed = UNSPECIFIED, False
|
||||
if isinstance(pjit_in_s, UnspecifiedValue):
|
||||
if isinstance(arg_s, UnspecifiedValue):
|
||||
resolved_in_shardings.append(arg_s)
|
||||
@ -1595,7 +1604,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
|
||||
'Please see the jax.Array migration guide for more information '
|
||||
'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
|
||||
f'Got arg shape: {arg.shape}, arg value: {arg}')
|
||||
if not isinstance(arg_s, UnspecifiedValue):
|
||||
if not isinstance(arg_s, UnspecifiedValue) and arg_s.is_concrete:
|
||||
# jax.jit does not allow resharding across different memory kinds even
|
||||
# if the argument is uncommitted. Use jax.device_put for those cases,
|
||||
# either outside or inside jax.jit.
|
||||
@ -1789,6 +1798,9 @@ def _pjit_lower(
|
||||
pgle_profiler: profiler.PGLEProfiler | None):
|
||||
util.test_event("pjit_lower")
|
||||
if config.sharding_in_types.value:
|
||||
if resource_env is not None:
|
||||
mesh, api_name = resource_env.physical_mesh, 'pjit'
|
||||
else:
|
||||
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
|
||||
else:
|
||||
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
|
||||
@ -2683,18 +2695,17 @@ batching.skippable_batchers[sharding_constraint_p] = lambda _: ()
|
||||
# TODO(yashkatariya): Make shardings optional.
|
||||
def mesh_cast(xs, out_shardings):
|
||||
x_flat, treedef = tree_flatten(xs)
|
||||
x_avals_flat = [core.shaped_abstractify(x) for x in x_flat]
|
||||
shardings_flat = flatten_axes("mesh_cast shardings", treedef, out_shardings)
|
||||
out_flat = [
|
||||
mesh_cast_p.bind(
|
||||
x, src_sharding=x_aval.sharding,
|
||||
dst_sharding=canonicalize_sharding(s, check_mesh_consistency=False))
|
||||
for x, x_aval, s in safe_zip(x_flat, x_avals_flat, shardings_flat)
|
||||
x, dst_sharding=canonicalize_sharding(s, check_mesh_consistency=False))
|
||||
for x, s in safe_zip(x_flat, shardings_flat)
|
||||
]
|
||||
return tree_unflatten(treedef, out_flat)
|
||||
|
||||
mesh_cast_p = core.Primitive('mesh_cast')
|
||||
def _mesh_cast_abstract_eval(aval, src_sharding, dst_sharding):
|
||||
def _mesh_cast_abstract_eval(aval, dst_sharding):
|
||||
src_sharding = aval.sharding
|
||||
if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple:
|
||||
raise ValueError(
|
||||
f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not'
|
||||
@ -2731,17 +2742,15 @@ def _mesh_cast_abstract_eval(aval, src_sharding, dst_sharding):
|
||||
return aval.update(sharding=dst_sharding)
|
||||
mesh_cast_p.def_abstract_eval(_mesh_cast_abstract_eval)
|
||||
|
||||
def _mesh_cast_impl(x, src_sharding, dst_sharding):
|
||||
return dispatch.apply_primitive(mesh_cast_p, x, src_sharding=src_sharding,
|
||||
dst_sharding=dst_sharding)
|
||||
def _mesh_cast_impl(x, dst_sharding):
|
||||
return dispatch.apply_primitive(mesh_cast_p, x, dst_sharding=dst_sharding)
|
||||
mesh_cast_p.def_impl(_mesh_cast_impl)
|
||||
|
||||
def _mesh_cast_transpose_rule(ct, _, src_sharding, dst_sharding):
|
||||
return [mesh_cast_p.bind(ct, src_sharding=dst_sharding,
|
||||
dst_sharding=src_sharding)]
|
||||
def _mesh_cast_transpose_rule(ct, x, dst_sharding):
|
||||
return [mesh_cast_p.bind(ct, dst_sharding=x.aval.sharding)]
|
||||
ad.deflinear2(mesh_cast_p, _mesh_cast_transpose_rule)
|
||||
|
||||
def _mesh_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
|
||||
def _mesh_cast_hlo_lowering(ctx, x_node, *, dst_sharding):
|
||||
aval, = ctx.avals_in
|
||||
aval_out, = ctx.avals_out
|
||||
proto = (dst_sharding._to_sdy_sharding(aval.ndim)
|
||||
@ -2750,27 +2759,18 @@ def _mesh_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
|
||||
return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)]
|
||||
mlir.register_lowering(mesh_cast_p, _mesh_cast_hlo_lowering)
|
||||
|
||||
# TODO(yashkatariya): Comment this in after vmap ShiT tests are added.
|
||||
# def _mesh_cast_batcher(axis_data, vals_in, dims_in, src_sharding,
|
||||
# dst_sharding):
|
||||
# if axis_data.spmd_name is not None:
|
||||
# used = {n for ns in dst_sharding.spec
|
||||
# for n in (ns if isinstance(ns, tuple) else (ns,))}
|
||||
# if set(axis_data.spmd_name) & used:
|
||||
# raise ValueError(
|
||||
# f'vmap spmd_axis_name {axis_data.spmd_name} cannot '
|
||||
# f'appear in mesh_cast spec, but got spec {dst_sharding.spec}')
|
||||
# x, = vals_in
|
||||
# d, = dims_in
|
||||
def _mesh_cast_batcher(axis_data, vals_in, dims_in, dst_sharding):
|
||||
assert axis_data.spmd_name is None
|
||||
x, = vals_in
|
||||
d, = dims_in
|
||||
|
||||
# val = None if axis_data.spmd_name is None else axis_data.spmd_name
|
||||
# new_spec = PartitionSpec(*util.tuple_insert(dst_sharding.spec, d, val))
|
||||
# vmapped_dst_sharding = NamedSharding(dst_sharding.mesh, new_spec)
|
||||
# y = mesh_cast_p.bind(x, src_sharding=src_sharding,
|
||||
# dst_sharding=vmapped_dst_sharding)
|
||||
# return y, d
|
||||
# batching.fancy_primitive_batchers[mesh_cast_p] = _mesh_cast_batcher
|
||||
# batching.skippable_batchers[mesh_cast_p] = lambda _: ()
|
||||
val = None
|
||||
new_spec = PartitionSpec(*util.tuple_insert(dst_sharding.spec, d, val))
|
||||
vmapped_dst_sharding = NamedSharding(dst_sharding.mesh, new_spec)
|
||||
y = mesh_cast_p.bind(x, dst_sharding=vmapped_dst_sharding)
|
||||
return y, d
|
||||
batching.fancy_primitive_batchers[mesh_cast_p] = _mesh_cast_batcher
|
||||
batching.skippable_batchers[mesh_cast_p] = lambda _: ()
|
||||
|
||||
# -------------------- reshard ------------------------------------
|
||||
|
||||
@ -2782,13 +2782,13 @@ def reshard(xs, out_shardings):
|
||||
for x, x_aval, s in safe_zip(x_flat, x_avals_flat, shardings_flat):
|
||||
ds = canonicalize_sharding(s)
|
||||
ds = ds.with_spec(ds.spec._normalized_spec(x_aval.ndim)) # type: ignore
|
||||
out_flat.append(reshard_p.bind(x, src_sharding=x_aval.sharding,
|
||||
dst_sharding=ds))
|
||||
out_flat.append(reshard_p.bind(x, dst_sharding=ds))
|
||||
return tree_unflatten(treedef, out_flat)
|
||||
|
||||
reshard_p = core.Primitive('reshard')
|
||||
|
||||
def _reshard_abstract_eval(aval, src_sharding, dst_sharding):
|
||||
def _reshard_abstract_eval(aval, dst_sharding):
|
||||
src_sharding = aval.sharding
|
||||
if src_sharding.mesh.abstract_mesh != dst_sharding.mesh.abstract_mesh:
|
||||
raise ValueError(
|
||||
f'Mesh of the input {src_sharding.mesh.abstract_mesh} does not'
|
||||
@ -2797,17 +2797,15 @@ def _reshard_abstract_eval(aval, src_sharding, dst_sharding):
|
||||
return aval.update(sharding=dst_sharding)
|
||||
reshard_p.def_abstract_eval(_reshard_abstract_eval)
|
||||
|
||||
def _reshard_impl(x, src_sharding, dst_sharding):
|
||||
return dispatch.apply_primitive(reshard_p, x, src_sharding=src_sharding,
|
||||
dst_sharding=dst_sharding)
|
||||
def _reshard_impl(x, dst_sharding):
|
||||
return dispatch.apply_primitive(reshard_p, x, dst_sharding=dst_sharding)
|
||||
reshard_p.def_impl(_reshard_impl)
|
||||
|
||||
def _reshard_transpose_rule(ct, _, src_sharding, dst_sharding):
|
||||
return [reshard_p.bind(ct, src_sharding=dst_sharding,
|
||||
dst_sharding=src_sharding)]
|
||||
def _reshard_transpose_rule(ct, x, dst_sharding):
|
||||
return [reshard_p.bind(ct, dst_sharding=x.aval.sharding)]
|
||||
ad.deflinear2(reshard_p, _reshard_transpose_rule)
|
||||
|
||||
def _reshard_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
|
||||
def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding):
|
||||
aval, = ctx.avals_in
|
||||
aval_out, = ctx.avals_out
|
||||
proto = (dst_sharding._to_sdy_sharding(aval.ndim)
|
||||
|
@@ -138,6 +138,10 @@ class Sharding:
   #############################################################################
   # Default implementations below that all subclasses will inherit.
 
+  @property
+  def is_concrete(self) -> bool:
+    return True
+
   @functools.cached_property
   def addressable_devices(self) -> set[Device]:
     """The set of devices in the :class:`Sharding` that are addressable by the
@@ -391,6 +391,12 @@ class NamedSharding(jsharding.Sharding):
     # mesh across multiple NamedSharding objects will be the same.
     return not self.mesh.is_multi_process
 
+  @property
+  def is_concrete(self) -> bool:
+    if isinstance(self.mesh, mesh_lib.AbstractMesh):
+      return False
+    return True
+
   @property
   def addressable_devices(self) -> set[Device]:
     if isinstance(self.mesh, mesh_lib.AbstractMesh):
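The new is_concrete property distinguishes device-backed shardings from shardings over an AbstractMesh. A small sketch of how calling code can use it; it relies on the property added in this commit, and the AbstractMesh constructor form follows the usage in this commit's tests:

# Sketch: is_concrete separates device-backed from abstract shardings.
import jax
import numpy as np
from jax.sharding import Mesh, AbstractMesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:1])
concrete = NamedSharding(Mesh(devices, ('x',)), P('x'))
abstract = NamedSharding(AbstractMesh((('x', 1),)), P('x'))

print(concrete.is_concrete)  # True: backed by real devices
print(abstract.is_concrete)  # False: only an AbstractMesh, no devices yet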
@ -1761,14 +1767,18 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
|
||||
return sharding # type: ignore
|
||||
if sharding is None:
|
||||
return sharding
|
||||
# TODO(yashkatariya): Remove this after vmap + shit works.
|
||||
if isinstance(sharding, NamedSharding) and sharding.mesh.empty:
|
||||
return None
|
||||
|
||||
cur_mesh = mesh_lib.get_abstract_mesh()
|
||||
if isinstance(sharding, PartitionSpec):
|
||||
sharding = NamedSharding(mesh_lib.get_abstract_mesh(), sharding) # type: ignore
|
||||
sharding = NamedSharding(cur_mesh, sharding) # type: ignore
|
||||
else:
|
||||
if (check_mesh_consistency and
|
||||
sharding.mesh.abstract_mesh != mesh_lib.get_abstract_mesh()):
|
||||
if (check_mesh_consistency and not cur_mesh.empty and
|
||||
sharding.mesh.abstract_mesh != cur_mesh):
|
||||
raise ValueError(
|
||||
f'Context mesh {mesh_lib.get_abstract_mesh()} should match the mesh'
|
||||
f'Context mesh {cur_mesh} should match the mesh'
|
||||
f' of sharding {sharding.mesh.abstract_mesh}. This error occurs at'
|
||||
f' source: {source_info_util.summarize(source_info_util.current())}')
|
||||
if isinstance(sharding.mesh, mesh_lib.Mesh):
|
||||
|
@ -737,7 +737,7 @@ class Traced(Stage):
|
||||
"_args_flat", "_arg_names", "_num_consts"]
|
||||
|
||||
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
|
||||
lower_callable, abstract_mesh=None,
|
||||
lower_callable, abstract_mesh=mesh_lib.AbstractMesh(()),
|
||||
args_flat=None, arg_names=None, num_consts: int = 0):
|
||||
self.jaxpr = jaxpr
|
||||
self.args_info = args_info
|
||||
|
@@ -320,6 +320,15 @@ class AbstractRef(core.AbstractValue):
           f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`."
       ) from None
 
+  @property
+  def sharding(self):
+    try:
+      return self.inner_aval.sharding  # pytype: disable=attribute-error
+    except AttributeError:
+      raise AttributeError(
+          f"`Ref{{{self.inner_aval.str_short()}}} has no `sharding`."
+      ) from None
+
   @core.aval_property
   def at(self):
     return RefIndexer(self)
@ -70,7 +70,6 @@ from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef,
|
||||
generate_key_paths, KeyPath)
|
||||
from jax.experimental.multihost_utils import (host_local_array_to_global_array,
|
||||
global_array_to_host_local_array)
|
||||
from jax._src.pjit import sharding_constraint_p
|
||||
|
||||
P = PartitionSpec
|
||||
|
||||
@ -473,6 +472,12 @@ shard_map_p = ShardMapPrimitive('shard_map')
|
||||
|
||||
# Staging
|
||||
|
||||
def _as_manual_mesh(mesh):
|
||||
if config.sharding_in_types.value:
|
||||
return AbstractMesh(
|
||||
mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names})
|
||||
return None
|
||||
|
||||
def _shard_map_staging(
|
||||
trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun,
|
||||
in_tracers: Sequence[Any], *, mesh: Mesh,
|
||||
@ -485,8 +490,9 @@ def _shard_map_staging(
|
||||
in_tracers = map(trace.to_jaxpr_tracer, in_tracers)
|
||||
in_avals = [t.aval for t in in_tracers]
|
||||
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
|
||||
manual_mesh = _as_manual_mesh(mesh)
|
||||
with (core.extend_axis_env_nd(list(mesh.shape.items())),
|
||||
set_abstract_mesh(pjit.get_abstract_mesh_from_avals(in_avals_))):
|
||||
set_abstract_mesh(manual_mesh)):
|
||||
jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
|
||||
_check_names(out_names_thunk(), out_avals_)
|
||||
if check_rep:
|
||||
@ -502,7 +508,8 @@ def _shard_map_staging(
|
||||
constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts))
|
||||
outvars = map(trace.makevar, out_tracers)
|
||||
in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore
|
||||
with core.extend_axis_env_nd(list(mesh.shape.items())):
|
||||
with (core.extend_axis_env_nd(list(mesh.shape.items())),
|
||||
set_abstract_mesh(manual_mesh)):
|
||||
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
|
||||
params = dict(mesh=mesh, in_names=in_names_staged,
|
||||
out_names=tuple(out_names_thunk()), jaxpr=jaxpr,
|
||||
@ -552,7 +559,9 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
|
||||
for i, sz in enumerate(aval.shape))
|
||||
if config.sharding_in_types.value:
|
||||
spec = _names_to_pspec(names)._normalized_spec(aval.ndim)
|
||||
new_sharding = NamedSharding(get_abstract_mesh(), spec)
|
||||
new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else
|
||||
get_abstract_mesh())
|
||||
new_sharding = NamedSharding(new_mesh, spec)
|
||||
else:
|
||||
new_sharding = None
|
||||
return aval.update(shape=new_shape, sharding=new_sharding)
|
||||
@ -787,7 +796,8 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
|
||||
del prim, auto
|
||||
if isinstance(mesh, AbstractMesh):
|
||||
mesh = get_mesh_from_args(args, mesh)
|
||||
args = map(partial(_unmatch_spec, mesh), in_names, args)
|
||||
args = map(partial(_unmatch_spec, mesh, context_mesh=get_abstract_mesh()),
|
||||
in_names, args)
|
||||
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
|
||||
outs, out_rep = _run_shmap(fun, mesh, args, in_rep, check_rep)
|
||||
out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs]
|
||||
@ -798,11 +808,14 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
|
||||
return map(partial(_match_spec, mesh, check_rep), pspecs, outs)
|
||||
core.EvalTrace.process_shard_map = _shard_map_impl
|
||||
|
||||
def _run_shmap(f, mesh, args, reps, check_rep):
|
||||
trace = ShardMapTrace(mesh, check_rep)
|
||||
def _run_shmap(f, mesh, args, reps, check_rep, context_mesh=None):
|
||||
context_mesh = get_abstract_mesh() if context_mesh is None else context_mesh
|
||||
trace = ShardMapTrace(mesh, check_rep, context_mesh)
|
||||
in_tracers = map(partial(ShardMapTracer, trace), reps, args)
|
||||
with core.set_current_trace(trace):
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
manual_mesh = _as_manual_mesh(mesh)
|
||||
with (core.set_current_trace(trace),
|
||||
core.extend_axis_env_nd(mesh.shape.items()),
|
||||
set_abstract_mesh(manual_mesh)):
|
||||
ans = f.call_wrapped(*in_tracers)
|
||||
outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans))
|
||||
return outs, out_rep
|
||||
@ -812,8 +825,9 @@ def _names_to_pspec(names: AxisNames) -> PartitionSpec:
|
||||
unpack = lambda t: t[0] if t is not None and len(t) == 1 else t
|
||||
return PartitionSpec(*(unpack(names.get(i)) for i in range(ndmin)))
|
||||
|
||||
def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType) -> JaxType:
|
||||
with core.eval_context(), jax.disable_jit(False):
|
||||
def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType, context_mesh) -> JaxType:
|
||||
with (core.eval_context(), jax.disable_jit(False),
|
||||
set_abstract_mesh(context_mesh)):
|
||||
return jax.jit(HashablePartial(_unmatch, mesh, tuple(src.items())))(x)
|
||||
|
||||
def _unmatch(mesh, src_tup, x):
|
||||
@ -850,18 +864,20 @@ def _match(mesh, check_rep, pspec, x):
|
||||
# TODO put back (?) needed for rep checking in eager? for now test rewrite
|
||||
return shard_map(_rem_singleton, mesh, (src,), pspec, check_rep=False)(x)
|
||||
|
||||
def _rem_singleton(x): return x.reshape(x.shape[1:])
|
||||
def _add_singleton(x): return x.reshape(1, *x.shape)
|
||||
def _rem_singleton(x): return jnp.squeeze(x, axis=0)
|
||||
def _add_singleton(x): return jnp.expand_dims(x, axis=0)
|
||||
|
||||
class ShardMapTrace(core.Trace):
|
||||
__slots__ = ("mesh", "check")
|
||||
__slots__ = ("mesh", "check", "context_mesh")
|
||||
|
||||
mesh: Mesh
|
||||
check: bool
|
||||
context_mesh: AbstractMesh
|
||||
|
||||
def __init__(self, mesh, check):
|
||||
def __init__(self, mesh, check, context_mesh):
|
||||
self.mesh = mesh
|
||||
self.check = check
|
||||
self.context_mesh = context_mesh
|
||||
|
||||
def to_val_rep_pair(self, val):
|
||||
if isinstance(val, ShardMapTracer):
|
||||
@ -869,7 +885,7 @@ class ShardMapTrace(core.Trace):
|
||||
elif isinstance(val, Tracer):
|
||||
raise Exception("Shouldn't have any non-shard_map tracers")
|
||||
else:
|
||||
val_ = _unmatch_spec(self.mesh, {}, val)
|
||||
val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh)
|
||||
return val_, None
|
||||
|
||||
def process_primitive(self, prim, tracers, params):
|
||||
@ -879,7 +895,8 @@ class ShardMapTrace(core.Trace):
|
||||
out_vals = eager_rule(self.mesh, *in_vals, **params)
|
||||
else:
|
||||
f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh)
|
||||
with core.eval_context(), jax.disable_jit(False):
|
||||
with (core.eval_context(), jax.disable_jit(False),
|
||||
set_abstract_mesh(self.context_mesh)):
|
||||
out_vals = jax.jit(f)(*in_vals)
|
||||
rep_rule = _check_rules.get(prim, partial(_rule_missing, prim))
|
||||
out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
|
||||
@ -910,7 +927,8 @@ class ShardMapTrace(core.Trace):
|
||||
raise NotImplementedError(msg)
|
||||
del prim, jvp, symbolic_zeros
|
||||
in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers))
|
||||
out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, in_rep, self.check)
|
||||
out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, in_rep, self.check,
|
||||
self.context_mesh)
|
||||
return map(partial(ShardMapTracer, self), out_rep, out_vals)
|
||||
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
|
||||
@ -922,7 +940,8 @@ class ShardMapTrace(core.Trace):
|
||||
raise NotImplementedError(msg)
|
||||
del prim, fwd, bwd, out_trees, symbolic_zeros
|
||||
in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers))
|
||||
out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, in_rep, self.check)
|
||||
out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, in_rep, self.check,
|
||||
self.context_mesh)
|
||||
return map(partial(ShardMapTracer, self), out_rep, out_vals)
|
||||
|
||||
|
||||
@ -938,17 +957,23 @@ class ShardMapTracer(core.Tracer):
|
||||
@property
|
||||
def aval(self):
|
||||
aval = core.get_aval(self.val)
|
||||
return core.mapped_aval(self._trace.mesh.size, 0, aval)
|
||||
out = core.mapped_aval(self._trace.mesh.size, 0, aval)
|
||||
if config.sharding_in_types.value:
|
||||
new_sharding = NamedSharding(_as_manual_mesh(self._trace.mesh),
|
||||
out.sharding.spec) # type: ignore
|
||||
else:
|
||||
new_sharding = None
|
||||
return out.update(sharding=new_sharding)
|
||||
|
||||
def to_concrete_value(self):
|
||||
if self.rep == set(self._trace.mesh.axis_names):
|
||||
with core.eval_context():
|
||||
with core.eval_context(), set_abstract_mesh(self._trace.context_mesh):
|
||||
return core.to_concrete_value(self.val[0])
|
||||
else:
|
||||
return None
|
||||
|
||||
def __str__(self) -> str:
|
||||
with core.eval_context():
|
||||
with core.eval_context(), set_abstract_mesh(self._trace.context_mesh):
|
||||
blocks = list(self.val)
|
||||
mesh = self._trace.mesh
|
||||
axis_names = f"({', '.join(map(str, mesh.axis_names))},)"
|
||||
@ -1135,7 +1160,7 @@ for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),
|
||||
|
||||
for p in [control_flow.loops.cumsum_p, control_flow.loops.cumlogsumexp_p,
|
||||
control_flow.loops.cumprod_p, control_flow.loops.cummax_p,
|
||||
control_flow.loops.cummin_p, sharding_constraint_p]:
|
||||
control_flow.loops.cummin_p, pjit.sharding_constraint_p]:
|
||||
register_standard_check(p)
|
||||
register_standard_rewrite(p)
|
||||
|
||||
@ -1706,11 +1731,11 @@ def _partial_eval_jaxpr_custom_rule(
|
||||
for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]
|
||||
eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
|
||||
eqn.primitive, params_known, jaxpr_known.effects,
|
||||
eqn.source_info)
|
||||
eqn.source_info, eqn.ctx)
|
||||
full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals)
|
||||
eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged,
|
||||
eqn.primitive, params_staged,
|
||||
jaxpr_staged.effects, eqn.source_info)
|
||||
jaxpr_staged.effects, eqn.source_info, eqn.ctx)
|
||||
assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
|
||||
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
|
||||
if type(x) is core.Var and not inst]
|
||||
@ -1795,7 +1820,7 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
|
||||
new_eqn = pe.new_jaxpr_eqn(
|
||||
[v for v, used in zip(eqn.invars, used_inputs) if used],
|
||||
[x for x, used in zip(eqn.outvars, used_outputs) if used],
|
||||
eqn.primitive, new_params, effs, eqn.source_info)
|
||||
eqn.primitive, new_params, effs, eqn.source_info, eqn.ctx)
|
||||
return used_inputs, new_eqn
|
||||
pe.dce_rules[shard_map_p] = _shard_map_dce
|
||||
|
||||
|
@@ -1326,6 +1326,12 @@ class ShardingTest(jtu.JaxTestCase):
                      axis_types={jax.sharding.AxisTypes.Auto: 'x'})
 
   def test_make_mesh_axis_types(self):
+    mesh1 = jax.sharding.AbstractMesh(
+        (('x', 2),), axis_types={jax.sharding.AxisTypes.Auto: 'x'})
+    mesh2 = jax.sharding.AbstractMesh(
+        (('x', 2),), axis_types={jax.sharding.AxisTypes.Auto: ('x',)})
+    self.assertEqual(mesh1, mesh2)
+
     mesh = jax.make_mesh((1, 1), ('x', 'y'))
     self.assertDictEqual(mesh.axis_types, {AxisTypes.Auto: ('x', 'y')})
 
@ -3523,20 +3523,17 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
f = pjit(mul, in_shardings=ns, out_shardings=ns)
|
||||
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
with (jtu.count_pjit_cpp_cache_miss() as count,
|
||||
jtu.count_jit_and_pmap_lowerings() as lowering_count):
|
||||
out = f(arr)
|
||||
cache_info1 = pxla._cached_lowering_to_hlo.cache_info()
|
||||
self.assertIsInstance(out.sharding, NamedSharding)
|
||||
|
||||
out2 = f(np_arr)
|
||||
cache_info2 = pxla._cached_lowering_to_hlo.cache_info()
|
||||
self.assertIsInstance(out2.sharding, NamedSharding)
|
||||
|
||||
# Drops out of C++ cache i.e. cache miss
|
||||
self.assertEqual(count(), 2)
|
||||
# Still gets a hit on pjit_lower cache.
|
||||
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
|
||||
self.assertEqual(cache_info2.misses, cache_info1.misses)
|
||||
self.assertEqual(lowering_count(), 1)
|
||||
|
||||
def test_list_in_pspec(self):
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
@ -4913,8 +4910,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out[0].sharding, arr1.sharding)
|
||||
self.assertEqual(out[1].sharding, arr2.sharding)
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
@jtu.with_user_mesh((4,), ('x',))
|
||||
def test_dot_general_out_sharding(self, mesh):
|
||||
def test_dot_general_out_sharding(self, jit, mesh):
|
||||
np_inp1 = np.arange(16.).reshape(8, 2)
|
||||
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
|
||||
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
|
||||
@ -4924,20 +4922,24 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.aval.sharding.spec, P('x', None))
|
||||
return jnp.sum(out)
|
||||
|
||||
out = f(arr1, arr2)
|
||||
self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp1.T))
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P()))
|
||||
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
|
||||
out = f(arr1, arr2)
|
||||
self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp1.T))
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P()))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'PartitionSpec cannot contain axis names that are of type Auto or'
|
||||
' Manual'):
|
||||
auto_axes(f, out_shardings=P())(arr1, arr2)
|
||||
|
||||
out = jax.grad(f, argnums=(0, 1))(arr1, arr2)
|
||||
self.assertEqual(out[0].sharding, arr1.sharding)
|
||||
self.assertEqual(out[1].sharding, arr2.sharding)
|
||||
|
||||
if jit:
|
||||
jitted_grad = jax.jit(jax.grad(f, argnums=(0, 1)))
|
||||
out = jitted_grad(arr1, arr2)
|
||||
self.assertEqual(out[0].sharding, arr1.sharding)
|
||||
@ -5977,7 +5979,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
{AxisTypes.Explicit: ('x',)})
|
||||
out = f(arr)
|
||||
self.assertArraysEqual(out, np_inp)
|
||||
self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
|
||||
self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Auto: ('x',)})
|
||||
|
||||
@jtu.with_user_mesh((2,), 'x')
|
||||
def test_inputs_different_context(self, mesh):
|
||||
@@ -6379,6 +6381,19 @@
 
     f(arr)  # doesn't crash
 
+  def test_shaped_array_input_to_jit_no_sharding(self):
+    # export_test.py has similar test but it's more complicated. This is a
+    # simplified version of a part of that test.
+    aval = core.ShapedArray((8,), jnp.int32)
+    aval2 = core.ShapedArray((8,), jnp.int32)
+
+    @jax.jit
+    def f(x, y):
+      return x * y
+
+    lowered_text = f.lower(aval, aval2).as_text()
+    self.assertNotIn("mhlo.sharding", lowered_text)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
@@ -58,7 +58,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, np_inp * np_inp * 4)
 
   def test_output_sharded_alike_input(self):
-    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     np_inp = np.arange(16).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     inp = jax.device_put(np_inp, s)
@@ -73,7 +73,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, np_inp * 2)
 
   def test_arange_shard_alike_jit(self):
-    mesh = jtu.create_mesh((2, 1), ('x', 'y'))
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     np_inp = np.arange(16).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     inp = jax.device_put(np_inp, s)