[sharding_in_types] Initial commit adding `varying_manual_axes: frozenset[AxisName]` to ShapedArray. Also add the `jax_varying_axes_in_types` config flag, which gates this feature while we develop it.

PiperOrigin-RevId: 736141670
This commit is contained in:
Yash Katariya 2025-03-12 08:28:21 -07:00 committed by jax authors
parent 8b7cfcb33c
commit abcc7fdf4c
3 changed files with 27 additions and 7 deletions

View File

@ -235,6 +235,7 @@ def trace_context():
threefry_partitionable.value,
threefry_gpu_kernel_lowering.value,
use_direct_linearize.value,
varying_axes_in_types.value,
softmax_custom_jvp.value,
disable_jit.value,
debug_key_reuse.value,
@ -1084,6 +1085,14 @@ use_direct_linearize = bool_state(
help=('Use direct linearization instead JVP followed by partial eval'),
include_in_jit_key=True)
# Opt-in flag (off by default) for tracking varying manual mesh axes on
# ShapedArray. Passed with include_in_jit_key=True.
varying_axes_in_types = bool_state(
    name='jax_varying_axes_in_types',
    default=False,
    help=(
        'Adds varying manual axes to ShapedArray to track which mesh axes '
        'the array is varying over. '
        'This will help to remove the efficient transpose rewrite machinery '
        'in shard_map'
    ),
    include_in_jit_key=True)
data_dependent_tracing_fallback = bool_state(
name='jax_data_dependent_tracing_fallback',
default=False,

View File

@ -1893,14 +1893,17 @@ def get_sharding(sharding, shape):
class ShapedArray(UnshapedArray):
__slots__ = ['shape', 'sharding'] # inherits slots from parent
__slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent
array_abstraction_level = 2
def __init__(self, shape, dtype, weak_type=False, *, sharding=None):
def __init__(self, shape, dtype, weak_type=False, *, sharding=None,
varying_manual_axes: frozenset[AxisName] = frozenset()):
self.shape = canonicalize_shape(shape)
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
self.sharding = get_sharding(sharding, self.shape)
if config.varying_axes_in_types.value:
self.varying_manual_axes = varying_manual_axes
def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
if shape is None:
@ -1911,6 +1914,9 @@ class ShapedArray(UnshapedArray):
weak_type = self.weak_type
if 'sharding' not in kwargs:
kwargs['sharding'] = self.sharding
if 'varying_manual_axes' not in kwargs:
kwargs['varying_manual_axes'] = getattr(self, 'varying_manual_axes',
frozenset())
return ShapedArray(shape, dtype, weak_type, **kwargs)
ndim = property(lambda self: len(self.shape))
@ -1927,17 +1933,22 @@ class ShapedArray(UnshapedArray):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape
and self.weak_type == other.weak_type
and self.sharding == other.sharding)
and self.sharding == other.sharding
and (getattr(self, 'varying_manual_axes', frozenset()) ==
getattr(other, 'varying_manual_axes', frozenset())))
def __hash__(self):
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
# the unique character code via hash(self.dtype.char)
return hash((self.shape, self.dtype, self.weak_type, self.sharding))
return hash((self.shape, self.dtype, self.weak_type, self.sharding,
getattr(self, 'varying_manual_axes', frozenset())))
def to_tangent_aval(self):
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type, sharding=self.sharding)
return ShapedArray(
self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type, sharding=self.sharding,
varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset()))
def str_short(self, short_dtypes=False, mesh_axis_types=False):
dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else

View File

@ -343,7 +343,7 @@ class BlockSpec:
if self.block_shape is None:
block_shape = array_aval.shape
else:
block_shape = self.block_shape
block_shape = self.block_shape # type: ignore
if len(array_aval.shape) != len(block_shape):
raise ValueError(
f"Block shape for {origin} (= {block_shape}) "