mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
[sharding_in_types] Make vmap work with shard_map + pallas
PiperOrigin-RevId: 718578207
This commit is contained in:
parent cd51e9dd14
commit 704b2e5fba
@@ -39,6 +39,7 @@ from jax._src import config
 from jax._src import effects
 from jax._src import compute_on
 from jax._src import mesh as mesh_lib
+from jax._src.mesh import AxisTypes
 from jax._src.partition_spec import PartitionSpec as P
 from jax._src.errors import (
     ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
@@ -1687,7 +1688,9 @@ def _invalid_shape_error(shape: Shape, context: str=""):
 
 # TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
 # Collective too.
-def modify_spec_for_hidden(spec, mesh) -> P:
+def modify_spec_for_hidden_collective(spec, mesh) -> P:
+  if all(s is None for s in spec):
+    return spec
   new_spec = []  # type: ignore
   for s in spec:
     if s is None:
@@ -1695,13 +1698,15 @@ def modify_spec_for_hidden(spec, mesh) -> P:
     else:
       temp_s = s[0] if isinstance(s, tuple) else s
       new_spec.append(
-          None if mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Hidden else s)
+          None
+          if mesh._name_to_type[temp_s] in (AxisTypes.Hidden, AxisTypes.Collective)
+          else s)
   return P(*new_spec)
 
 def _maybe_modify_sharding(sharding):
-  if mesh_lib.AxisTypes.Hidden not in sharding.mesh.axis_types:
+  if sharding.mesh._are_all_axes_visible:
     return sharding
-  new_spec = modify_spec_for_hidden(sharding.spec, sharding.mesh)
+  new_spec = modify_spec_for_hidden_collective(sharding.spec, sharding.mesh)
   return sharding.with_spec(new_spec)
 
 
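For readers outside the JAX codebase, the spec filtering above can be summarized with a self-contained sketch: any mesh axis whose type is Hidden or Collective is dropped from a partition spec (its entry becomes None), while Visible axes are kept. The names below (AxisType, filter_spec, name_to_type) are illustrative stand-ins, not the internal jax._src APIs.

import enum

class AxisType(enum.Enum):  # placeholder for the mesh axis types in this diff
  Visible = enum.auto()
  Hidden = enum.auto()
  Collective = enum.auto()

def filter_spec(spec, name_to_type):
  """Drop spec entries whose mesh axis is Hidden or Collective (sketch)."""
  if all(s is None for s in spec):
    return tuple(spec)
  new_spec = []
  for s in spec:
    if s is None:
      new_spec.append(None)
      continue
    # A spec entry may be a tuple of axis names; the first name decides the type.
    name = s[0] if isinstance(s, tuple) else s
    hidden_like = name_to_type[name] in (AxisType.Hidden, AxisType.Collective)
    new_spec.append(None if hidden_like else s)
  return tuple(new_spec)

# Example: axis 'x' is visible, 'y' is collective (e.g. inside shard_map).
types = {'x': AxisType.Visible, 'y': AxisType.Collective}
print(filter_spec(('x', 'y', None), types))  # ('x', None, None)

The early return for an all-None spec mirrors the short-circuit added to modify_spec_for_hidden_collective above.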
@@ -5585,7 +5585,7 @@ def _reduce_prod_jvp_rule(primals, tangents, *, axes):
 
 reduce_prod_p = standard_primitive(
     _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'),
-    'reduce_prod')
+    'reduce_prod', sharding_rule=_reduce_op_sharding_rule)
 ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
 batching.defreducer(reduce_prod_p, _get_prod_identity)
 pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, _reduce_prod,
@@ -5613,8 +5613,9 @@ pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
 batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule
 
 
-reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
-                                  'reduce_min')
+reduce_min_p = standard_primitive(
+    _reduce_op_shape_rule, _input_dtype, 'reduce_min',
+    sharding_rule=_reduce_op_sharding_rule)
 ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
 batching.defreducer(reduce_min_p, _get_min_identity)
 pe.padding_rules[reduce_min_p] = partial(_reducer_padding, _reduce_min,
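This hunk and the previous one make the same kind of change: the reduce primitives are re-registered with an extra, optional sharding_rule. As a generic illustration of the pattern (PrimitiveRecord and apply_rules are invented for this sketch, not JAX's standard_primitive machinery), the rule is simply an additional callback that is consulted only when present:

from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class PrimitiveRecord:
  """Toy stand-in for a registered primitive and its optional rules."""
  name: str
  shape_rule: Callable[..., tuple]
  sharding_rule: Optional[Callable[..., tuple]] = None  # new optional hook

def apply_rules(prim, shape, spec, axes):
  out_shape = prim.shape_rule(shape, axes)
  # Only consult the sharding rule if this primitive registered one.
  out_spec = prim.sharding_rule(spec, axes) if prim.sharding_rule else None
  return out_shape, out_spec

def reduce_shape_rule(shape, axes):
  return tuple(d for i, d in enumerate(shape) if i not in axes)

without_rule = PrimitiveRecord('reduce_min', reduce_shape_rule)
with_rule = PrimitiveRecord(
    'reduce_min', reduce_shape_rule,
    sharding_rule=lambda spec, axes: tuple(
        s for i, s in enumerate(spec) if i not in axes))

print(apply_rules(without_rule, (8, 16), ('x', None), {1}))  # ((8,), None)
print(apply_rules(with_rule, (8, 16), ('x', None), {1}))     # ((8,), ('x',))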
@@ -5705,22 +5706,25 @@ def _reduce_logical_shape_rule(operand, *, axes):
     raise TypeError(f"logical reduction requires operand dtype bool or int, got {operand.dtype}.")
   return tuple(np.delete(operand.shape, axes))
 
+def _reduce_logical_sharding_rule(operand, *, axes):
+  return operand.sharding.with_spec(tuple_delete(operand.sharding.spec, axes))
+
 reduce_or_p = standard_primitive(
     _reduce_logical_shape_rule, _input_dtype, 'reduce_or',
-    weak_type_rule=_strip_weak_type)
+    weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule)
 batching.defreducer(reduce_or_p, _get_bitwise_or_identity)
 
 
 reduce_and_p = standard_primitive(
     _reduce_logical_shape_rule, _input_dtype, 'reduce_and',
-    weak_type_rule=_strip_weak_type)
+    weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule)
 batching.defreducer(reduce_and_p, _get_bitwise_and_identity)
 batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule
 
 
 reduce_xor_p = standard_primitive(
     _reduce_logical_shape_rule, _input_dtype, 'reduce_xor',
-    weak_type_rule=_strip_weak_type)
+    weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule)
 batching.defreducer(reduce_xor_p, _get_bitwise_or_identity)
 
 
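Conceptually the new _reduce_logical_sharding_rule mirrors the logical shape rule: reducing over a set of axes deletes those dimensions, so the matching entries of the partition spec are deleted too. A self-contained sketch of that idea (tuple_delete and the tuple-based spec here are illustrative, not the internal helpers):

def tuple_delete(t, axes):
  """Remove the entries of t at the given positions (sketch)."""
  return tuple(x for i, x in enumerate(t) if i not in axes)

def reduce_shape(shape, axes):
  # Shape rule: the reduced dimensions disappear.
  return tuple_delete(shape, axes)

def reduce_spec(spec, axes):
  # Sharding rule: their spec entries disappear with them.
  return tuple_delete(spec, axes)

shape, spec, axes = (8, 16, 4), ('x', None, 'y'), {1}
print(reduce_shape(shape, axes))  # (8, 4)
print(reduce_spec(spec, axes))    # ('x', 'y')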
@@ -50,7 +50,8 @@ def _get_array_abstraction_level(a): return a.array_abstraction_level
 def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
   if config.sharding_in_types.value:
     if rule is None:
-      if mesh_lib.get_abstract_mesh()._are_all_axes_hidden:  # type: ignore
+      cur_mesh = mesh_lib.get_abstract_mesh()
+      if cur_mesh._are_all_axes_hidden or cur_mesh._are_all_axes_collective:  # type: ignore
         return None if num_out is None else [None] * num_out
       else:
         raise ValueError(
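The rewritten fallback widens the "no sharding rule registered" case: when every axis of the current abstract mesh is hidden or collective, the missing rule is tolerated and each output simply gets no sharding; only a mesh with visible axes insists on an explicit rule. A standalone sketch of that decision (the function name and error text are illustrative):

def fallback_sharding(rule, num_out, all_axes_hidden_or_collective):
  """Sketch of the 'primitive has no sharding rule' fallback."""
  if rule is not None:
    return rule()
  if all_axes_hidden_or_collective:
    # Nothing in the mesh is explicitly sharded, so "no sharding" is fine.
    return None if num_out is None else [None] * num_out
  raise ValueError("a sharding rule is required when some mesh axes are visible")

print(fallback_sharding(None, 2, True))               # [None, None]
print(fallback_sharding(None, None, True))            # None
print(fallback_sharding(lambda: ["spec"], 1, False))  # ['spec']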
@@ -475,6 +475,10 @@ class AbstractMesh:
   def _are_all_axes_hidden(self) -> bool:
     return all(t == AxisTypes.Hidden for t in self.axis_types.keys())
 
+  @functools.cached_property
+  def _are_all_axes_visible(self) -> bool:
+    return all(t == AxisTypes.Visible for t in self.axis_types.keys())
+
   @functools.cached_property
   def _any_axis_collective(self) -> bool:
     return any(t == AxisTypes.Collective for t in self.axis_types.keys())
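The new AbstractMesh predicate is just another memoized scan over the mesh's axis-type table. A minimal standalone sketch of the same pattern, with the table simplified to a plain axis-name to type dictionary (ToyMesh and AxisType are invented for illustration; the real AbstractMesh stores its table differently):

import enum
import functools

class AxisType(enum.Enum):
  Visible = enum.auto()
  Hidden = enum.auto()
  Collective = enum.auto()

class ToyMesh:
  def __init__(self, axis_types):
    # Simplified table: axis name -> AxisType.
    self.axis_types = axis_types

  @functools.cached_property
  def _are_all_axes_visible(self) -> bool:
    return all(t == AxisType.Visible for t in self.axis_types.values())

  @functools.cached_property
  def _any_axis_collective(self) -> bool:
    return any(t == AxisType.Collective for t in self.axis_types.values())

m = ToyMesh({'x': AxisType.Visible, 'y': AxisType.Collective})
print(m._are_all_axes_visible, m._any_axis_collective)  # False True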
@@ -921,7 +921,8 @@ def jaxpr_subcomp(
         name_stack=ctx.name_stack + eqn.source_info.name_stack
     )
     loc = mlir._source_info_to_location(ctx, eqn.primitive, source_info)
-    with source_info_util.user_context(eqn.source_info.traceback), loc:
+    with (source_info_util.user_context(eqn.source_info.traceback), loc,
+          eqn.ctx.manager):
       if eqn.primitive in lowering_rules:
         if eqn.primitive not in skip_mlir_conversions:
           invals = [_ensure_mlir_value(x, v.aval)
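This lowering change only widens an existing with statement so that each equation's own context manager (eqn.ctx.manager) is entered alongside the source-info context and MLIR location while that equation is lowered. The parenthesized multi-manager form requires Python 3.10+; a small self-contained sketch with stand-in managers shows the enter/exit order:

import contextlib

@contextlib.contextmanager
def tag(name, log):
  # Stand-in for user_context(...), an MLIR location, or eqn.ctx.manager.
  log.append(f"enter {name}")
  try:
    yield
  finally:
    log.append(f"exit {name}")

log = []
with (tag("source_info", log), tag("location", log), tag("eqn_ctx", log)):
  log.append("lower one equation")
print(log)
# ['enter source_info', 'enter location', 'enter eqn_ctx',
#  'lower one equation', 'exit eqn_ctx', 'exit location', 'exit source_info']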
@@ -1117,10 +1117,15 @@ def _pallas_call_batching_rule(
     # assert ragged_axis_length is not None
     args = (ragged_axis_length, *args)
   assert all(isinstance(aval, jax_core.ShapedArray) for aval in out_avals)
-  batched_out_avals = tuple(
-      aval.update(shape=tuple_insert(aval.shape, 0, axis_size))
-      for aval in out_avals
-  )
+
+  batched_out_avals = []
+  for aval in out_avals:
+    sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, 0, None))
+                if config.sharding_in_types.value else None)
+    shape = tuple_insert(aval.shape, 0, axis_size)
+    batched_out_avals.append(aval.update(shape=shape, sharding=sharding))
+  batched_out_avals = tuple(batched_out_avals)  # type: ignore
+
   out = pallas_call_p.bind(
       *dynamic_grid_args,
       *args,
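The batching-rule change replaces the shape-only update of the output avals with one that also grows the partition spec: the new leading batch dimension of size axis_size gets a None (unsharded) entry whenever shardings are tracked in types. A standalone sketch of that bookkeeping (ToyAval, tuple_insert and batch_aval are illustrative stand-ins, not the pallas internals):

from dataclasses import dataclass, replace
from typing import Optional

def tuple_insert(t, pos, val):
  return t[:pos] + (val,) + t[pos:]

@dataclass(frozen=True)
class ToyAval:
  shape: tuple
  spec: Optional[tuple] = None  # partition spec, when shardings are tracked

def batch_aval(aval, axis_size, sharding_in_types):
  # New leading batch dim; it is never sharded, so its spec entry is None.
  spec = (tuple_insert(aval.spec, 0, None)
          if sharding_in_types and aval.spec is not None else aval.spec)
  return replace(aval, shape=tuple_insert(aval.shape, 0, axis_size), spec=spec)

out_avals = (ToyAval((8, 128), ('x', None)), ToyAval((8,), ('x',)))
batched = tuple(batch_aval(a, 4, sharding_in_types=True) for a in out_avals)
print(batched[0])  # ToyAval(shape=(4, 8, 128), spec=(None, 'x', None))
print(batched[1])  # ToyAval(shape=(4, 8), spec=(None, 'x'))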
@@ -2840,7 +2840,7 @@ def hidden_axes(fun, *, axes: str | tuple[str, ...] | None = None,
   def decorator(*args, **kwargs):
     new_mesh = _get_new_mesh(axes, mesh_lib.AxisTypes.Hidden)
     with mesh_lib.set_abstract_mesh(new_mesh):
-      in_specs = tree_map(lambda a: core.modify_spec_for_hidden(
+      in_specs = tree_map(lambda a: core.modify_spec_for_hidden_collective(
           a.aval.sharding.spec, new_mesh), args)
       args = mesh_cast(args, in_specs)
       out = fun(*args, **kwargs)
@@ -2861,7 +2861,7 @@ def visible_axes(fun, *, axes: str | tuple[str, ...] | None = None,
     with mesh_lib.set_abstract_mesh(new_mesh):
       args = mesh_cast(args, in_shardings)
       out = fun(*args, **kwargs)
-      out_specs = tree_map(lambda o: core.modify_spec_for_hidden(
+      out_specs = tree_map(lambda o: core.modify_spec_for_hidden_collective(
           o.aval.sharding.spec, mesh_lib.get_abstract_mesh()), out)
       return mesh_cast(out, out_specs)
   return decorator
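Both decorators follow the same shape: switch the ambient abstract mesh, recompute the relevant specs with the renamed core.modify_spec_for_hidden_collective, cast through mesh_cast, and run the wrapped function. A rough, self-contained sketch of that wrapper shape, with mesh handling reduced to a plain dictionary and every helper name invented for illustration:

import contextlib

@contextlib.contextmanager
def set_abstract_mesh(state, mesh):
  # Stand-in for switching the ambient mesh; restores the previous one on exit.
  prev, state["mesh"] = state["mesh"], mesh
  try:
    yield
  finally:
    state["mesh"] = prev

def filter_spec(spec, hidden_names):
  # Stand-in for the spec filtering done by modify_spec_for_hidden_collective.
  return tuple(None if s in hidden_names else s for s in spec)

def hidden_axes_like(fun, hidden_names, state):
  # Wrapper shape: switch the mesh, filter the input spec, run the function.
  def decorator(value, spec):
    with set_abstract_mesh(state, "hidden"):
      return fun(value, filter_spec(spec, hidden_names))
  return decorator

state = {"mesh": "outer"}
f = hidden_axes_like(lambda v, spec: (v * 2, spec), hidden_names={"y"}, state=state)
print(f(3, ("x", "y")))  # (6, ('x', None)): axis 'y' is hidden inside the wrapper
print(state["mesh"])     # 'outer': the ambient mesh is restored afterwards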