[shard-map] fix eager shmap+prngs, revise phys aval/sharding logic

Co-authored-by: Yash Katariya <yashkatariya@google.com>
Yash Katariya 2023-04-05 14:09:46 -07:00 committed by Matthew Johnson
parent f6da71c807
commit 728a5ed96a
7 changed files with 123 additions and 102 deletions


@@ -545,11 +545,15 @@ def flatten_lowering_ir_args(
_module_name_regex = re.compile(r"[^\w.-]")
def sharded_aval(aval: core.ShapedArray,
sharding: Optional[xc.OpSharding]) -> core.ShapedArray:
def sharded_aval(aval: core.AbstractValue,
sharding: Optional[xc.OpSharding]) -> core.AbstractValue:
"""Returns the new aval sharded based on sharding proto."""
if sharding is None:
return aval
if isinstance(aval, core.AbstractToken):
return aval
if not isinstance(aval, core.ShapedArray):
raise NotImplementedError
if (sharding.type == xc.OpSharding.Type.REPLICATED or
sharding.type == xc.OpSharding.Type.MANUAL):
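
Note on the revised sharded_aval: it now accepts any core.AbstractValue, passing tokens through and rejecting other non-ShapedArray avals, so callers can map it uniformly over mixed input/output avals. A minimal sketch of that dispatch with stand-in classes (all Fake* names are illustrative, and the tiled branch, which the hunk truncates, is approximated here by an even per-dimension split):

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class FakeShapedArray:          # stands in for core.ShapedArray
    shape: Tuple[int, ...]

class FakeToken:                # stands in for core.AbstractToken
    pass

@dataclass
class FakeOpSharding:           # only the fields this sketch consults
    replicated: bool
    tile_assignment_dimensions: Tuple[int, ...]

def sharded_aval_sketch(aval, sharding: Optional[FakeOpSharding]):
    if sharding is None:
        return aval                       # no sharding: aval unchanged
    if isinstance(aval, FakeToken):
        return aval                       # tokens carry no shardable data
    if not isinstance(aval, FakeShapedArray):
        raise NotImplementedError(type(aval))
    if sharding.replicated:
        return aval
    # tiled case (approximation): each dim divided by its tile count
    shape = tuple(d // t for d, t in zip(aval.shape,
                                         sharding.tile_assignment_dimensions))
    return FakeShapedArray(shape)

assert sharded_aval_sketch(FakeShapedArray((8, 4)),
                           FakeOpSharding(False, (2, 1))).shape == (4, 4)
assert isinstance(sharded_aval_sketch(FakeToken(), FakeOpSharding(True, ())),
                  FakeToken)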
@@ -620,22 +624,10 @@ def lower_jaxpr_to_module(
if not xb.is_known_platform(platform):
raise ValueError(f"Unknown platform {platform}")
input_output_aliases = None
in_avals = jaxpr.in_avals
if arg_shardings is not None:
in_avals = [
sharded_aval(in_aval, in_sharding)
for in_aval, in_sharding in zip(in_avals, arg_shardings)
]
out_avals = jaxpr.out_avals
if result_shardings is not None:
out_avals = []
for out_aval, out_sharding in zip(jaxpr.out_avals, result_shardings):
if (out_aval is not core.abstract_token and
core.is_opaque_dtype(out_aval.dtype)):
# TODO(frostig,mattjj,necula): asserts a single physical aval
out_aval, = out_aval.dtype._rules.physical_avals(out_aval)
out_avals.append(sharded_aval(out_aval, out_sharding))
in_avals = (jaxpr.in_avals if arg_shardings is None else
map(sharded_aval, jaxpr.in_avals, arg_shardings))
out_avals = (jaxpr.out_avals if result_shardings is None else
map(sharded_aval, jaxpr.out_avals, result_shardings))
if platform in _platforms_with_donation:
input_output_aliases, donated_args = _set_up_aliases(
in_avals, out_avals, donated_args)
@@ -643,9 +635,7 @@ def lower_jaxpr_to_module(
if unlowerable_effects:
raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
if any(donated_args):
# TODO(tomhennigan): At call time we should mark these buffers as deleted.
unused_donations = [str(a) for a, d in zip(in_avals, donated_args)
if d]
unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d]
msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
if platform not in _platforms_with_donation:
msg = f"Donation is not implemented for {platform}.\n{msg}"
@@ -660,8 +650,7 @@ def lower_jaxpr_to_module(
dim_vars: Sequence[str]
if not config.jax_dynamic_shapes:
# Find the dimension variables
all_dim_poly = [d
for aval in jaxpr.in_avals if hasattr(aval, "shape")
all_dim_poly = [d for aval in jaxpr.in_avals if hasattr(aval, "shape")
for d in aval.shape if not core.is_constant_dim(d)]
dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new.get_vars()),
all_dim_poly, set())))
@@ -886,14 +875,18 @@ def lower_jaxpr_to_fun(
ctx.symbol_table.insert(func_op)
ir_arg_shardings = None
if arg_shardings is not None:
in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals)
ir_arg_shardings = util.flatten(
[[sharding] * len(types) for sharding, types
in zip(arg_shardings, input_types)])
[[_to_physical_op_sharding(a, s)] * len(types)
for a, s, types in zip(in_avals, arg_shardings, input_types)])
del in_avals
ir_result_shardings = None
if result_shardings is not None:
out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals)
ir_result_shardings = util.flatten(
[[sharding] * len(types)
for sharding, types in zip(result_shardings, output_types)])
[[_to_physical_op_sharding(a, s)] * len(types)
for a, s, types in zip(out_avals, result_shardings, output_types)])
del out_avals
if (replicated_args is not None or ir_arg_shardings is not None
or input_output_aliases is not None):
@@ -1001,6 +994,15 @@ def lower_jaxpr_to_fun(
return func_op
def _to_physical_op_sharding(
aval: Optional[core.AbstractValue], sharding: Optional[xc.OpSharding]
) -> Optional[xc.OpSharding]:
if (isinstance(aval, core.ShapedArray) and core.is_opaque_dtype(aval.dtype)
and sharding is not None):
return aval.dtype._rules.physical_op_sharding(aval, sharding)
return sharding
def _emit_lowering_rule_as_fun(lowering_rule,
ctx: LoweringRuleContext) -> func_dialect.FuncOp:
"""Emits the contents of a lowering rule as a private function."""


@@ -2327,13 +2327,15 @@ def _get_and_check_device_assignment(
return xb.get_device_backend(final_device_assignment[0]), final_device_assignment
MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]
@profiler.annotate_function
def lower_sharding_computation(
fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
api_name: str,
fun_name: str,
in_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]],
out_shardings: Union[Sequence[Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]], UnspecifiedValue],
in_shardings: Sequence[MaybeSharding],
out_shardings: Union[Sequence[MaybeSharding], UnspecifiedValue],
donated_invars: Sequence[bool],
global_in_avals: Sequence[core.ShapedArray],
*,
@@ -2384,17 +2386,17 @@ def lower_sharding_computation(
if _is_unspecified(out_shardings):
out_shardings = (_UNSPECIFIED,) * len(global_out_avals)
# mypy doesn't understand that out_sharding here is always a sequence.
assert len(out_shardings) == len(global_out_avals), ( # type: ignore
len(out_shardings), len(global_out_avals)) # type: ignore
assert isinstance(out_shardings, tuple)
assert len(out_shardings) == len(global_out_avals), (
len(out_shardings), len(global_out_avals))
# Device assignment across all inputs, outputs and shardings inside jaxpr
# should be the same.
jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
backend, device_assignment = _get_and_check_device_assignment(
it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings],
[(o, MismatchType.OUT_SHARDING, None) for o in out_shardings], # type: ignore
[(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) # type: ignore
[(o, MismatchType.OUT_SHARDING, None) for o in out_shardings],
[(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
for js, source_info in jaxpr_sharding]),
devices_from_context)
@@ -2402,8 +2404,8 @@ def lower_sharding_computation(
devices_from_context or
len(device_assignment) > 1 or
any(not _is_unspecified(i) for i in in_shardings) or
any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or # type: ignore
any(not _is_unspecified(o) for o in out_shardings)) # type: ignore
any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or
any(not _is_unspecified(o) for o in out_shardings))
in_shardings = tuple(sharding_impls.GSPMDSharding.get_replicated(device_assignment)
if _is_unspecified(i) else i for i in in_shardings)
@@ -2445,7 +2447,7 @@ def lower_sharding_computation(
# and don't need to evaluate their arguments.
if (not always_lower and not (jaxpr.effects or has_outfeed) and
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
all(_is_unspecified(o) for o in out_shardings)): # type: ignore
all(_is_unspecified(o) for o in out_shardings)):
return MeshComputation(
str(name_stack), None, True, donated_invars, jaxpr=jaxpr, consts=consts,
global_in_avals=global_in_avals, global_out_avals=global_out_avals,
@@ -2469,25 +2471,8 @@ def lower_sharding_computation(
axis_ctx: mlir.AxisContext
if nreps == 1:
in_op_shardings = []
for aval, i in safe_zip(global_in_avals, in_shardings):
if aval is core.abstract_token:
in_op_shardings.append(None)
elif core.is_opaque_dtype(aval.dtype):
in_op_shardings.append(aval.dtype._rules.physical_op_sharding(aval, i))
else:
in_op_shardings.append(i._to_xla_op_sharding(aval.ndim)) # type: ignore[union-attr]
# TODO(yashkatariya): Fix the HLO produced if out_partitions is
# [None, OpShardingProto] has the sharding annotations.
out_op_shardings = []
for aval, o in safe_zip(global_out_avals, out_shardings): # type: ignore[arg-type]
if _is_unspecified(o) or aval is core.abstract_token:
out_op_shardings.append(None)
elif core.is_opaque_dtype(aval.dtype):
out_op_shardings.append(aval.dtype._rules.physical_op_sharding(aval, o))
else:
out_op_shardings.append(o._to_xla_op_sharding(aval.ndim)) # type: ignore[union-attr]
in_op_shardings = map(_to_logical_op_sharding, global_in_avals, in_shardings)
out_op_shardings = map(_to_logical_op_sharding, global_out_avals, out_shardings)
replicated_args = [False] * len(global_in_avals)
axis_ctx = mlir.ShardingContext(device_assignment)
else:
@@ -2556,6 +2541,19 @@ def lower_sharding_computation(
committed=committed,
pmap_nreps=nreps)
def _to_logical_op_sharding(
aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource]
) -> Optional[xc.OpSharding]:
if _is_unspecified(sharding) or is_auto(sharding):
return None
elif isinstance(aval, ShapedArray):
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
return sharding._to_xla_op_sharding(aval.ndim)
elif isinstance(aval, core.AbstractToken):
return None
else:
raise TypeError(aval)
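
_to_logical_op_sharding centralizes what the two deleted loops did per (aval, sharding) pair: unspecified/auto shardings and tokens lower to None, while ShapedArrays yield their XLA OpSharding proto. A small stand-in sketch of that dispatch (the sentinels and classes here are illustrative, not the real pjit/sharding types):

UNSPECIFIED, AUTO = object(), object()      # stand-ins for the pjit sentinels

class FakeToken:                            # stands in for core.AbstractToken
    pass

class FakeShapedArray:
    def __init__(self, shape): self.shape = shape
    @property
    def ndim(self): return len(self.shape)

class FakeNamedSharding:                    # stands in for an XLACompatibleSharding
    def _to_xla_op_sharding(self, ndim):
        return f'<OpSharding over {ndim} dims>'

def to_logical_op_sharding_sketch(aval, sharding):
    if sharding in (UNSPECIFIED, AUTO):
        return None                         # let the compiler decide
    if isinstance(aval, FakeShapedArray):
        return sharding._to_xla_op_sharding(aval.ndim)
    if isinstance(aval, FakeToken):
        return None                         # tokens are never sharded
    raise TypeError(aval)

avals = [FakeShapedArray((8, 2)), FakeToken(), FakeShapedArray((3,))]
shardings = [FakeNamedSharding(), FakeNamedSharding(), UNSPECIFIED]
print(list(map(to_logical_op_sharding_sketch, avals, shardings)))
# ['<OpSharding over 2 dims>', None, None]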
@profiler.annotate_function
def lower_mesh_computation(
@@ -2575,7 +2573,7 @@ def lower_mesh_computation(
backend = xb.get_device_backend(mesh.devices.flat[0])
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
auto_spmd_lowering = check_if_any_auto(in_shardings + out_shardings) # type: ignore
auto_spmd_lowering = check_if_any_auto((*in_shardings, *out_shardings))
if auto_spmd_lowering and not spmd_lowering:
raise ValueError('Enable spmd_lowering to use auto spmd lowering.')
@@ -2653,25 +2651,8 @@ def lower_mesh_computation(
out_partitions: Optional[List[Optional[xc.OpSharding]]]
axis_ctx: mlir.AxisContext
if spmd_lowering:
in_partitions = []
for aval, i in safe_zip(global_in_avals, in_shardings):
if is_auto(i):
in_partitions.append(None)
elif core.is_opaque_dtype(aval.dtype):
in_partitions.append(aval.dtype._rules.physical_op_sharding(aval, i))
else:
in_partitions.append(i._to_xla_op_sharding(aval.ndim)) # type: ignore[union-attr]
# TODO(yashkatariya): Fix the HLO produced if out_partitions is
# [None, OpShardingProto] has the sharding annotations.
out_partitions = []
for aval, o in safe_zip(global_out_avals, out_shardings):
if is_auto(o) or _is_unspecified(o):
out_partitions.append(None)
elif core.is_opaque_dtype(aval.dtype):
out_partitions.append(aval.dtype._rules.physical_op_sharding(aval, o))
else:
out_partitions.append(o._to_xla_op_sharding(aval.ndim)) # type: ignore[union-attr]
in_partitions = map(_to_logical_op_sharding, global_in_avals, in_shardings)
out_partitions = map(_to_logical_op_sharding, global_out_avals, out_shardings)
replicated_args = [False] * len(in_jaxpr_avals)
axis_ctx = mlir.SPMDAxisContext(mesh, manual_axes)
else:


@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
from functools import partial, reduce
@@ -176,7 +176,12 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
_device = property(op.attrgetter('_base_array._device'))
_committed = property(op.attrgetter('_base_array._committed'))
sharding = property(op.attrgetter('_base_array.sharding'))
@property
def sharding(self):
aval = keys_shaped_array(self.impl, self.shape)
phys_sharding = self._base_array.sharding
return KeyTyRules.logical_op_sharding(aval, phys_sharding)
def _is_scalar(self):
base_ndim = len(self.impl.key_shape)
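
User-visible effect of the new sharding property: a typed key array now reports a sharding computed for its logical aval rather than for the underlying uint32 buffer. A hedged usage sketch (assumes the jax_enable_custom_prng flag of this era; the exact repr of the returned sharding may differ):

import jax
jax.config.update('jax_enable_custom_prng', True)   # opt in to typed key arrays

keys = jax.random.split(jax.random.PRNGKey(0), 4)   # logical shape (4,), physical (4, 2)
print(keys.shape)      # (4,)
# Previously this exposed the physical buffer's sharding directly; now it is
# derived via KeyTyRules.logical_op_sharding, so it describes the (4,) shape.
print(keys.sharding)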
@@ -250,14 +255,14 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
# `typing.type_check_only`.
@property
def T(self) -> 'PRNGKeyArray': assert False
def __getitem__(self, _) -> 'PRNGKeyArray': assert False
def ravel(self, *_, **__) -> 'PRNGKeyArray': assert False
def squeeze(self, *_, **__) -> 'PRNGKeyArray': assert False
def swapaxes(self, *_, **__) -> 'PRNGKeyArray': assert False
def take(self, *_, **__) -> 'PRNGKeyArray': assert False
def transpose(self, *_, **__) -> 'PRNGKeyArray': assert False
def flatten(self, *_, **__) -> 'PRNGKeyArray': assert False
def T(self) -> PRNGKeyArray: assert False
def __getitem__(self, _) -> PRNGKeyArray: assert False
def ravel(self, *_, **__) -> PRNGKeyArray: assert False
def squeeze(self, *_, **__) -> PRNGKeyArray: assert False
def swapaxes(self, *_, **__) -> PRNGKeyArray: assert False
def take(self, *_, **__) -> PRNGKeyArray: assert False
def transpose(self, *_, **__) -> PRNGKeyArray: assert False
def flatten(self, *_, **__) -> PRNGKeyArray: assert False
_set_device_array_base_attributes(PRNGKeyArray, include=[
@@ -303,9 +308,10 @@ def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla):
elif is_sharding_from_xla:
return sharding
else:
sharding_proto = sharding._to_xla_op_sharding(aval.ndim)
return GSPMDSharding(
sharding._device_assignment,
KeyTyRules.physical_op_sharding(aval, sharding))
KeyTyRules.physical_op_sharding(aval, sharding_proto))
class KeyTyRules:
@@ -316,16 +322,25 @@ class KeyTyRules:
jnp.dtype('uint32'))]
@staticmethod
def physical_op_sharding(aval, sharding):
op_sharding = sharding._to_xla_op_sharding(aval.ndim)
def physical_op_sharding(aval, op_sharding_proto):
key_shape = aval.dtype.impl.key_shape
new_op_sharding = op_sharding.clone()
new_op_sharding = op_sharding_proto.clone()
tad = list(new_op_sharding.tile_assignment_dimensions)
tad.extend([1] * len(key_shape))
new_op_sharding.tile_assignment_dimensions = tad
return new_op_sharding
@staticmethod
def logical_op_sharding(aval, phys_sharding) -> GSPMDSharding:
key_shape = aval.dtype.impl.key_shape
phys_op_sharding = phys_sharding._to_xla_op_sharding(
aval.ndim + len(key_shape))
logical_op_sharding = phys_op_sharding.clone()
tad = list(logical_op_sharding.tile_assignment_dimensions)
tad = tad[:-len(key_shape)]
logical_op_sharding.tile_assignment_dimensions = tad
return GSPMDSharding(phys_sharding._device_assignment, logical_op_sharding)
@staticmethod
def result_handler(sticky_device, aval):
def handler(_, buf):
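
physical_op_sharding and the new logical_op_sharding are inverses on the tile assignment: one appends a trivial tile dimension per trailing key dimension, the other trims those dimensions off again. A toy round trip with a stand-in proto (key_shape is (2,), as for the threefry impl):

class FakeOpSharding:                        # minimal stand-in for xc.OpSharding
    def __init__(self, tad): self.tile_assignment_dimensions = list(tad)
    def clone(self): return FakeOpSharding(self.tile_assignment_dimensions)

key_shape = (2,)                             # one trailing physical dim of size 2
logical = FakeOpSharding([4, 1])             # logical sharding of a 2-D key array

physical = logical.clone()                   # physical_op_sharding direction
physical.tile_assignment_dimensions += [1] * len(key_shape)
assert physical.tile_assignment_dimensions == [4, 1, 1]

recovered = physical.clone()                 # logical_op_sharding direction
recovered.tile_assignment_dimensions = (
    recovered.tile_assignment_dimensions[:-len(key_shape)])
assert recovered.tile_assignment_dimensions == [4, 1]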
@@ -367,7 +382,6 @@ class KeyTyRules:
phys_sharding = make_key_array_phys_sharding(
aval, out_sharding, is_out_sharding_from_xla)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
is_out_sharding_from_xla)
def handler(bufs):


@@ -33,6 +33,7 @@ from jax._src import debugging
from jax._src import linear_util as lu
from jax._src import ops
from jax._src import pjit
from jax._src import prng
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
@@ -473,9 +474,11 @@ def _xla_shard(mesh, names, aval_in, aval_out, x):
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
axes = {name: i for i, ns in names.items() for name in ns}
sharding_proto = pxla.new_mesh_sharding_specs(mesh.shape, mesh.axis_names)(
aval_in.ndim, axes).sharding_proto()
sx = mlir.wrap_with_sharding_op(x, sharding_proto, unspecified_dims=set())
shard_proto = NamedSharding(mesh, pxla.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_op_sharding(aval_in.ndim)
if core.is_opaque_dtype(aval_in.dtype):
shard_proto = aval_in.dtype._rules.physical_op_sharding(aval_in, shard_proto)
sx = mlir.wrap_with_sharding_op(x, shard_proto, unspecified_dims=set())
return [mlir.wrap_with_full_to_shard_op(result_type, sx, manual_proto, set())]
def _xla_unshard(mesh, names, aval_in, aval_out, xs):
@@ -484,9 +487,11 @@ def _xla_unshard(mesh, names, aval_in, aval_out, xs):
result_type, = mlir.aval_to_ir_types(aval_out)
sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=set())
axes = {name: i for i, ns in names.items() for name in ns}
sharding_proto = pxla.new_mesh_sharding_specs(mesh.shape, mesh.axis_names)(
aval_out.ndim, axes).sharding_proto()
return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, set())
shard_proto = NamedSharding(mesh, pxla.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_op_sharding(aval_out.ndim)
if core.is_opaque_dtype(aval_out.dtype):
shard_proto = aval_out.dtype._rules.physical_op_sharding(aval_out, shard_proto)
return mlir.wrap_with_shard_to_full_op(result_type, sx, shard_proto, set())
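
Both _xla_shard and _xla_unshard now derive their annotation proto from a NamedSharding built out of the axes mapping, then route it through the dtype's physical_op_sharding rule when the aval has an opaque dtype (key arrays). A hedged sketch of the proto construction for a rank-2 value with one mesh axis (assumes at least four devices; the PartitionSpec shown approximates what array_mapping_to_axis_resources yields for axes == {'x': 0}):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
axes = {'x': 0}                     # mesh axis name -> positional dimension
spec = P('x', None)                 # roughly array_mapping_to_axis_resources(axes)
shard_proto = NamedSharding(mesh, spec)._to_xla_op_sharding(2)  # private hook used above
# For an opaque dtype the proto would additionally be passed through
# aval.dtype._rules.physical_op_sharding(aval, shard_proto) before annotating.
print(shard_proto)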
# Eager evaluation
@@ -652,7 +657,7 @@ def _standard_rep_rule(_, *in_rep, **__):
for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),
windowed_reductions.__dict__.values(), fft.__dict__.values(),
linalg.__dict__.values(), ops.__dict__.values(),
ad_util.__dict__.values(),
ad_util.__dict__.values(), prng.__dict__.values(),
custom_derivatives.__dict__.values()):
if isinstance(o, core.Primitive): register_standard(o)
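
Adding jax._src.prng to this sweep is what registers the standard replication rule for the key-array primitives, which eager shard_map needs when a body touches PRNG keys. A small sketch of what the sweep picks up from that module (the exact primitive set is illustrative and may vary by JAX version):

from jax._src import core, prng

prng_prims = [v for v in vars(prng).values() if isinstance(v, core.Primitive)]
print(sorted(p.name for p in prng_prims))
# expected to include names like 'random_bits', 'random_seed', 'random_wrap',
# 'random_unwrap', 'threefry2x32'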


@@ -2839,8 +2839,11 @@ class FooTyRules:
return [core.ShapedArray((*aval.shape, 2), jnp.dtype('uint32'))]
@staticmethod
def physical_op_sharding(aval, sharding):
return sharding._to_xla_op_sharding(aval.ndim)
def physical_op_sharding(aval, op_sharding_proto):
new_op_sharding = op_sharding_proto.clone()
tad = list(new_op_sharding.tile_assignment_dimensions)
new_op_sharding.tile_assignment_dimensions = [*tad, 1]
return new_op_sharding
@staticmethod
def result_handler(sticky_device, aval):


@@ -36,7 +36,7 @@ from jax import stages
from jax.errors import JAXTypeError
from jax import lax
from jax.lax import with_sharding_constraint
from jax import prng
from jax._src import prng
from jax.sharding import PartitionSpec as P
from jax.experimental.maps import xmap
from jax.experimental import multihost_utils


@@ -553,6 +553,22 @@ class ShardMapTest(jtu.JaxTestCase):
jax.eval_shape(jax.grad(lambda x: jax.remat(f)(x).sum().astype('float32')),
xs)
def test_prngkeyarray_eager(self):
# https://github.com/google/jax/issues/15398
mesh = jtu.create_global_mesh((4,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
rng = jax.random.PRNGKey(0)
sharded_rng = jax.random.split(rng, num=4)
sharded_rng = jax.device_put(sharded_rng, sharding)
def f(key):
return jax.random.randint(key[0], shape=(1, 16), minval=0, maxval=16,
dtype=jnp.int32)
g = shard_map(f, mesh, in_specs=(P('x', None),), out_specs=P('x', None))
_ = g(sharded_rng) # don't crash!
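
The same regression (google/jax issue #15398) can be exercised outside the test harness roughly as follows; this sketch swaps jtu.create_global_mesh for a plain Mesh and assumes at least four devices:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
keys = jax.device_put(jax.random.split(jax.random.PRNGKey(0), num=4),
                      NamedSharding(mesh, P('x')))

def f(key):
    # each shard receives one key row and draws its own block of integers
    return jax.random.randint(key[0], shape=(1, 16), minval=0, maxval=16,
                              dtype=jnp.int32)

g = shard_map(f, mesh, in_specs=(P('x', None),), out_specs=P('x', None))
print(g(keys).shape)   # (4, 16); crashed in eager mode before this fix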
class FunSpec(NamedTuple):
name: str