mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sharding_in_types] Handle ShapeDtypeStruct inputs with sharding_in_types by properly registering the sharding on the aval created by SDS in its pytype_aval_mapping.
Also, if we are running under full auto mode, don't error out if primitives don't have a sharding rule registered. PiperOrigin-RevId: 715383866
This commit is contained in:
parent
a1bbad6863
commit
c72ed260fe
@ -68,7 +68,8 @@ from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
|
||||
from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
|
||||
NamedSharding)
|
||||
from jax._src.layout import Layout, AutoLayout
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src import tree_util
|
||||
@ -2562,9 +2563,14 @@ class ShapeDtypeStruct:
|
||||
return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type))
|
||||
|
||||
def _sds_aval_mapping(x):
|
||||
return ShapedArray(
|
||||
aval = ShapedArray(
|
||||
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
|
||||
weak_type=x.weak_type)
|
||||
if config.sharding_in_types.value and isinstance(x.sharding, NamedSharding):
|
||||
return aval.update(sharding=NamedSharding(
|
||||
x.sharding.mesh.abstract_mesh,
|
||||
x.sharding.spec._normalized_spec(x.ndim)))
|
||||
return aval
|
||||
core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping
|
||||
|
||||
|
||||
|
@ -1420,6 +1420,8 @@ def check_valid_jaxtype(x):
|
||||
# TODO(jakevdp): can these be unified further?
|
||||
|
||||
def shaped_abstractify(x):
|
||||
from jax._src.sharding_impls import NamedSharding # type: ignore
|
||||
|
||||
typ = type(x)
|
||||
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
|
||||
return aval_fn(x)
|
||||
@ -1431,7 +1433,14 @@ def shaped_abstractify(x):
|
||||
if hasattr(x, '__jax_array__'):
|
||||
return shaped_abstractify(x.__jax_array__())
|
||||
if hasattr(x, 'dtype'):
|
||||
return ShapedArray(np.shape(x), x.dtype, weak_type=getattr(x, 'weak_type', False))
|
||||
aval = ShapedArray(np.shape(x), x.dtype,
|
||||
weak_type=getattr(x, 'weak_type', False))
|
||||
if (config.sharding_in_types.value and hasattr(x, 'sharding') and
|
||||
isinstance(x.sharding, NamedSharding)):
|
||||
return aval.update(sharding=NamedSharding(
|
||||
x.sharding.mesh.abstract_mesh,
|
||||
x.sharding.spec._normalized_spec(aval.ndim)))
|
||||
return aval
|
||||
raise TypeError(
|
||||
f"Cannot interpret value of type {typ} as an abstract array; it "
|
||||
"does not have a dtype attribute")
|
||||
|
@ -6552,7 +6552,13 @@ def _const(example, val):
|
||||
return np.array(val, dtype)
|
||||
|
||||
_zeros: Callable = partial(full_like, fill_value=0)
|
||||
_zero: Callable = partial(full_like, shape=(), fill_value=0)
|
||||
|
||||
def _zero(x):
  """Build a shape-``()`` zero from ``x`` via ``full_like``.

  The result is ``full_like(x, shape=(), fill_value=0)``. When
  ``config.sharding_in_types`` is enabled, the call additionally receives
  ``sharding=x.sharding.with_spec(P())``.
  """
  # NOTE(review): `P` is presumably PartitionSpec, so `P()` would be an empty
  # spec for the rank-0 result -- confirm against this module's imports.
  extra_kwargs = {}
  if config.sharding_in_types.value:
    extra_kwargs['sharding'] = x.sharding.with_spec(P())  # type: ignore
  return full_like(x, shape=(), fill_value=0, **extra_kwargs)
|
||||
|
||||
_ones: Callable = partial(full_like, fill_value=1)
|
||||
|
||||
def _one(x):
|
||||
|
@ -22,6 +22,7 @@ from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import config
|
||||
from jax._src import dtypes
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.util import safe_zip
|
||||
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
@ -46,6 +47,13 @@ def standard_primitive(shape_rule, dtype_rule, name,
|
||||
|
||||
def _get_array_abstraction_level(a): return a.array_abstraction_level
|
||||
|
||||
def call_sharding_rule(rule, num_out, *avals, **kwargs):
  """Compute the output sharding(s) of a primitive, tolerating a missing rule.

  Args:
    rule: the primitive's sharding rule, invoked as ``rule(*avals, **kwargs)``,
      or ``None`` when the primitive has no rule registered.
    num_out: ``None`` for a single-result primitive, otherwise the number of
      results.
    *avals: input abstract values, forwarded to ``rule``.
    **kwargs: primitive parameters, forwarded to ``rule``.

  Returns:
    ``rule(*avals, **kwargs)`` when ``config.sharding_in_types`` is enabled and
    a rule is available. Otherwise a "no sharding" placeholder: ``None`` when
    ``num_out`` is ``None``, else a list of ``num_out`` ``None`` entries.

  Raises:
    TypeError: if ``config.sharding_in_types`` is enabled, ``rule`` is ``None``
      and the current abstract mesh does not have all of its axes auto.
  """
  if config.sharding_in_types.value:
    if rule is not None:
      return rule(*avals, **kwargs)
    # A missing rule is only acceptable when every mesh axis is auto. Calling
    # `None` here used to fail with an opaque "'NoneType' object is not
    # callable"; raise the same exception type with an actionable message.
    if not mesh_lib.get_abstract_mesh()._are_all_axes_auto:  # type: ignore
      raise TypeError(
          "No sharding rule is registered for this primitive, and the current "
          "abstract mesh does not have all of its axes in auto mode. Register "
          "a sharding rule for the primitive or run it under a mesh whose "
          "axes are all auto.")
  return None if num_out is None else [None] * num_out
|
||||
|
||||
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
|
||||
sharding_rule, *avals, **kwargs):
|
||||
assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
|
||||
@ -57,8 +65,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
|
||||
out_aval = core.ShapedArray(
|
||||
shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
|
||||
weak_type=weak_type,
|
||||
sharding=(sharding_rule(*avals, **kwargs)
|
||||
if config.sharding_in_types.value else None))
|
||||
sharding=call_sharding_rule(sharding_rule, None, *avals, **kwargs))
|
||||
core.check_avals_context_mesh([out_aval], prim.name)
|
||||
return out_aval
|
||||
elif least_specialized is core.DShapedArray:
|
||||
@ -82,9 +89,8 @@ def standard_multi_result_abstract_eval(
|
||||
out_shapes = shape_rule(*avals, **kwargs)
|
||||
out_dtypes = dtype_rule(*avals, **kwargs)
|
||||
core.check_avals_context_mesh(avals, prim.name)
|
||||
out_shardings = (sharding_rule(*avals, **kwargs)
|
||||
if config.sharding_in_types.value else
|
||||
[None] * len(out_shapes))
|
||||
out_shardings = call_sharding_rule(
|
||||
sharding_rule, len(out_shapes), *avals, **kwargs)
|
||||
out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh)
|
||||
for s, d, weak_type, sh in zip(out_shapes, out_dtypes,
|
||||
weak_types, out_shardings)]
|
||||
|
@ -452,6 +452,10 @@ class AbstractMesh:
|
||||
new_axis_types = axis_types_to_names(updated_name_to_type)
|
||||
return AbstractMesh(self.shape_tuple, axis_types=new_axis_types)
|
||||
|
||||
@property
def abstract_mesh(self):
  """Return this mesh object itself.

  NOTE(review): presumably provided so that callers can read
  ``.abstract_mesh`` on any mesh object without checking its type first --
  confirm against the concrete mesh class.
  """
  return self
|
||||
|
||||
@functools.cached_property
def _are_all_axes_collective(self) -> bool:
  """Whether every key of ``self.axis_types`` equals ``AxisTypes.Collective``.

  Vacuously True when ``self.axis_types`` has no keys. The result is stored
  on the instance by ``functools.cached_property`` after the first access.
  """
  for axis_type in self.axis_types.keys():
    if not (axis_type == AxisTypes.Collective):
      return False
  return True
|
||||
|
@ -4776,7 +4776,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.sharding, s)
|
||||
self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
|
||||
|
||||
lowered_text = f.lower(arr).as_text()
|
||||
sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=s)
|
||||
lowered_text = f.lower(sds).as_text()
|
||||
if config.use_shardy_partitioner.value:
|
||||
self.assertEqual(lowered_text.count('sdy.sharding_constraint'), 3)
|
||||
else:
|
||||
@ -4793,6 +4794,8 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
out = jax.jit(jax.grad(g))(arr)
|
||||
self.assertEqual(out.sharding, arr.sharding)
|
||||
|
||||
jax.jit(jax.grad(g)).lower(sds) # doesn't crash
|
||||
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_fully_replicated_array_mul(self, mesh):
|
||||
np_inp1 = np.arange(16).reshape(8, 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user