From abcc7fdf4c18a2e20a31355c64fc767867703c1c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 12 Mar 2025 08:28:21 -0700 Subject: [PATCH] [sharding_in_types] Initial commit to add `varying_manual_axes: frozenset[AxisName]` to ShapedArray. Also add `jax_varying_axes_in_types` config to hide this option under while we develop it. PiperOrigin-RevId: 736141670 --- jax/_src/config.py | 9 +++++++++ jax/_src/core.py | 23 +++++++++++++++++------ jax/_src/pallas/core.py | 2 +- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 00f65726a..1e46fb8bd 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -235,6 +235,7 @@ def trace_context(): threefry_partitionable.value, threefry_gpu_kernel_lowering.value, use_direct_linearize.value, + varying_axes_in_types.value, softmax_custom_jvp.value, disable_jit.value, debug_key_reuse.value, @@ -1084,6 +1085,14 @@ use_direct_linearize = bool_state( help=('Use direct linearization instead JVP followed by partial eval'), include_in_jit_key=True) +varying_axes_in_types = bool_state( + name='jax_varying_axes_in_types', + default=False, + help=('Adds varying manual axes to ShapedArray to track which mesh axes the' + ' array is varying over. This will help to remove the efficient' + ' transpose rewrite machinery in shard_map'), + include_in_jit_key=True) + data_dependent_tracing_fallback = bool_state( name='jax_data_dependent_tracing_fallback', default=False, diff --git a/jax/_src/core.py b/jax/_src/core.py index b17e26255..e53aec755 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1893,14 +1893,17 @@ def get_sharding(sharding, shape): class ShapedArray(UnshapedArray): - __slots__ = ['shape', 'sharding'] # inherits slots from parent + __slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent array_abstraction_level = 2 - def __init__(self, shape, dtype, weak_type=False, *, sharding=None): + def __init__(self, shape, dtype, weak_type=False, *, sharding=None, + varying_manual_axes: frozenset[AxisName] = frozenset()): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type self.sharding = get_sharding(sharding, self.shape) + if config.varying_axes_in_types.value: + self.varying_manual_axes = varying_manual_axes def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: @@ -1911,6 +1914,9 @@ class ShapedArray(UnshapedArray): weak_type = self.weak_type if 'sharding' not in kwargs: kwargs['sharding'] = self.sharding + if 'varying_manual_axes' not in kwargs: + kwargs['varying_manual_axes'] = getattr(self, 'varying_manual_axes', + frozenset()) return ShapedArray(shape, dtype, weak_type, **kwargs) ndim = property(lambda self: len(self.shape)) @@ -1927,17 +1933,22 @@ class ShapedArray(UnshapedArray): return (type(self) is type(other) and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type - and self.sharding == other.sharding) + and self.sharding == other.sharding + and (getattr(self, 'varying_manual_axes', frozenset()) == + getattr(other, 'varying_manual_axes', frozenset()))) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) - return hash((self.shape, self.dtype, self.weak_type, self.sharding)) + return hash((self.shape, self.dtype, self.weak_type, self.sharding, + getattr(self, 'varying_manual_axes', frozenset()))) def to_tangent_aval(self): - return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type, sharding=self.sharding) + return ShapedArray( + self.shape, primal_dtype_to_tangent_dtype(self.dtype), + self.weak_type, sharding=self.sharding, + varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset())) def str_short(self, short_dtypes=False, mesh_axis_types=False): dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index ad6ce6ab4..466f6037a 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -343,7 +343,7 @@ class BlockSpec: if self.block_shape is None: block_shape = array_aval.shape else: - block_shape = self.block_shape + block_shape = self.block_shape # type: ignore if len(array_aval.shape) != len(block_shape): raise ValueError( f"Block shape for {origin} (= {block_shape}) "