Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)

Improve the error raised when a sharding rule produces a spec that maps one mesh axis to multiple positional dimensions.

Previously it was:
`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`

Now it is:
`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`

PiperOrigin-RevId: 736657644
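A minimal repro sketch of the new error, assuming a two-device mesh with axis 'x' and JAX's explicit-sharding (sharding-in-types) mode; the exact mesh/AxisType/use_mesh spellings are assumptions that shift across JAX versions, and jax_enable_x64 is needed to get i64 as in the message above:

    import jax
    import jax.numpy as jnp
    from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

    jax.config.update('jax_enable_x64', True)
    mesh = jax.make_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))
    # lhs is i64[8@x,2], rhs is i64[2,8@x]
    lhs = jax.device_put(jnp.zeros((8, 2), jnp.int64),
                         NamedSharding(mesh, P('x', None)))
    rhs = jax.device_put(jnp.zeros((2, 8), jnp.int64),
                         NamedSharding(mesh, P(None, 'x')))
    with jax.sharding.use_mesh(mesh):
      lhs @ rhs  # result would need spec P('x', 'x') -> the TypeError above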
170 lines · 6.9 KiB · Python
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This module contains utility functions split out of jax._src.lax.lax to
# avoid cyclic dependencies. Definitions that are used at import time by
# multiple modules can go here.
from functools import partial

import numpy as np

from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import mesh as mesh_lib
from jax._src.named_sharding import NamedSharding, DuplicateSpecError
from jax._src.partition_spec import PartitionSpec as P
from jax._src.util import safe_zip

# Shadow the builtin: safe_zip raises if its arguments have mismatched
# lengths, catching bugs that builtin zip would silently truncate.
zip, unsafe_zip = safe_zip, zip


# A dtype rule: the output dtype is the (canonicalized) dtype of the first
# input.
def _input_dtype(x, *_, **__):
  return dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)


# Builds a weak-type rule: the output is weakly typed iff every input at the
# given argument positions is weakly typed.
def _argnum_weak_type(*argnums):
  return lambda *args, **_: all(args[i].weak_type for i in argnums)


def standard_primitive(shape_rule, dtype_rule, name,
                       weak_type_rule=None, sharding_rule=None):
  weak_type_rule = weak_type_rule or _standard_weak_type_rule
  prim = core.Primitive(name)
  prim.def_impl(partial(dispatch.apply_primitive, prim))
  prim.def_abstract_eval(
      partial(standard_abstract_eval, prim, shape_rule, dtype_rule,
              weak_type_rule, sharding_rule))
  return prim
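

# For illustration, a hedged sketch of how a primitive could be registered
# with standard_primitive. The names below are invented for this example, not
# existing JAX APIs; the rule signatures mirror the partial(...) application
# above (each rule receives the input avals plus the primitive's keyword
# parameters):
#
#   def _my_negate_shape_rule(x):
#     return x.shape
#
#   def _my_negate_dtype_rule(x):
#     return x.dtype
#
#   my_negate_p = standard_primitive(
#       _my_negate_shape_rule, _my_negate_dtype_rule, 'my_negate')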


def _get_array_abstraction_level(a): return a.array_abstraction_level


# Returns the one abstract mesh shared by all input avals, ignoring tokens
# and inputs on an empty mesh. Distinct fully-auto meshes collapse to the
# empty mesh; any other mesh mismatch is an error.
def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh:
  m = None
  for a in in_avals:
    if a is core.abstract_token:
      continue
    if a.sharding.mesh.empty:
      continue
    if m is not None and m != a.sharding.mesh:
      if m._are_all_axes_auto and a.sharding.mesh._are_all_axes_auto:
        return mesh_lib.empty_abstract_mesh
      raise ValueError(
          f'Mesh for all inputs should be equal. Got one mesh: {m} and'
          f' another mesh: {a.sharding.mesh}')
    m = a.sharding.mesh
  return mesh_lib.empty_abstract_mesh if m is None else m
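

# For example (hypothetical avals): given a token, an input on an empty mesh,
# and two inputs on the same mesh M, this returns M; two distinct meshes raise
# the ValueError above unless both are fully auto, in which case the empty
# mesh is returned.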


def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
  cur_mesh = mesh_lib.get_abstract_mesh()
  aval_mesh = _get_abstract_mesh_from_avals(avals)
  # Fast path: if both the context mesh and the avals' mesh are empty or
  # fully auto/manual, there is no explicit sharding to propagate; return an
  # unconstrained sharding instead of invoking the rule.
  if ((cur_mesh.empty or cur_mesh._are_all_axes_auto_or_manual) and
      (aval_mesh.empty or aval_mesh._are_all_axes_auto_or_manual)):
    aval_mesh = cur_mesh if aval_mesh.empty else aval_mesh
    s = NamedSharding(aval_mesh, P())
    return s if num_out is None else [s] * num_out
  if rule is None:
    raise ValueError(
        f'sharding rule for {prim.name} is not implemented. Please file a'
        ' bug at https://github.com/jax-ml/jax/issues. You can work around'
        ' this error by dropping that operation into full auto sharding'
        ' mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`')
  return rule(*avals, **kwargs)
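

# The workaround named in the error message, sketched as a hedged example
# (jax.experimental APIs evolve, so treat the exact import path as an
# assumption; `fun` and the out_shardings value are placeholders):
#
#   from jax.experimental.shard import auto_axes
#   fun_fixed = auto_axes(fun, out_shardings=...)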


def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule,
                                   multi_out, *avals, **kwargs):
  out_shapes = shape_rule(*avals, **kwargs)
  out_dtypes = dtype_rule(*avals, **kwargs)
  num_out = len(out_shapes) if multi_out else None
  try:
    out_shardings = call_sharding_rule(
        prim, sharding_rule, num_out, *avals, **kwargs)
  except DuplicateSpecError as e:
    # Re-raise a duplicate-axis spec error as a TypeError that names the
    # primitive and shows the input and (illegally sharded) output avals.
    if multi_out:
      raise
    avals_str = ', '.join(i.str_short(short_dtypes=True) for i in avals)
    mesh = mesh_lib.empty_abstract_mesh if e.mesh is None else e.mesh
    out_aval_str = core.str_short_aval(out_shapes, out_dtypes, mesh, e.pspec,
                                       short_dtypes=True)
    raise TypeError(
        f'{prim} operation with inputs: {avals_str} produces an illegally'
        f' sharded result: {out_aval_str}') from e
  return out_shapes, out_dtypes, out_shardings
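

# The TypeError above is the improved message from the change description:
# a duplicate-axis spec is now reported in terms of the op and its avals, e.g.
#   dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an
#   illegally sharded result: i64[8@x,8@x]
# instead of the earlier NamedSharding ValueError about PartitionSpec('x', 'x').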


def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                           sharding_rule, *avals, **kwargs):
  assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
  assert not prim.multiple_results
  weak_type = weak_type_rule(*avals, **kwargs)
  # max over abstraction level picks the least specialized (most abstract)
  # input aval type, which determines the output aval type.
  least_specialized = type(max(avals, key=_get_array_abstraction_level))
  if least_specialized is core.ShapedArray:
    core.check_avals_context_mesh(avals, prim.name)
    out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule(
        prim, shape_rule, dtype_rule, sharding_rule, False,
        *avals, **kwargs)
    out_aval = core.ShapedArray(
        out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding)
    core.check_avals_context_mesh([out_aval], prim.name)
    return out_aval
  elif least_specialized is core.DShapedArray:
    shape = shape_rule(*avals, **kwargs)
    ty = (core.ShapedArray if all(type(d) is int for d in shape)
          else core.DShapedArray)
    return ty(shape, dtype_rule(*avals, **kwargs), weak_type)
  elif least_specialized is core.UnshapedArray:
    return core.UnshapedArray(dtype_rule(*avals, **kwargs), weak_type=weak_type)
  else:
    raise TypeError(avals, least_specialized)


def standard_multi_result_abstract_eval(
    prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule,
    *avals, **kwargs):
  assert prim.multiple_results
  assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
  least_specialized = max(map(type, avals), key=_get_array_abstraction_level)
  weak_types = weak_type_rule(*avals, **kwargs)
  if least_specialized is core.ShapedArray:
    core.check_avals_context_mesh(avals, prim.name)
    out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule(
        prim, shape_rule, dtype_rule, sharding_rule, True, *avals, **kwargs)
    if isinstance(weak_types, bool):
      weak_types = (weak_types,) * len(out_shapes)
    out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh)
                 for s, d, weak_type, sh in zip(out_shapes, out_dtypes,
                                                weak_types, out_shardings)]
    core.check_avals_context_mesh(out_avals, prim.name)
    return out_avals
  elif least_specialized is core.UnshapedArray:
    out_dtypes = dtype_rule(*avals, **kwargs)
    if isinstance(weak_types, bool):
      weak_types = (weak_types,) * len(out_dtypes)
    return [core.UnshapedArray(dtype, weak_type=weak_type)
            for dtype, weak_type in zip(out_dtypes, weak_types)]
  else:
    raise TypeError(avals, least_specialized)


def _standard_weak_type_rule(*avals, **kwargs):
  return all(aval.weak_type for aval in avals)


def dtype_to_string(dtype):
  # Prefer the numpy dtype name; fall back to the dtype's own .name (e.g. for
  # extended dtypes with no numpy equivalent), and finally to plain str().
  try:
    return str(np.dtype(dtype).name)
  except TypeError:
    pass
  try:
    return dtype.name
  except AttributeError:
    pass
  return str(dtype)
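

# For instance (hypothetical values): dtype_to_string(np.float32) returns
# 'float32' via np.dtype(...).name, while an extended dtype that np.dtype
# rejects falls through to its .name attribute or, failing that, str().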