mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[sharding_in_types] Initial commit to add varying_manual_axes: frozenset[AxisName]
to ShapedArray. Also add jax_varying_axes_in_types
config to hide this option under while we develop it.
PiperOrigin-RevId: 736141670
This commit is contained in:
parent
8b7cfcb33c
commit
abcc7fdf4c
@ -235,6 +235,7 @@ def trace_context():
|
||||
threefry_partitionable.value,
|
||||
threefry_gpu_kernel_lowering.value,
|
||||
use_direct_linearize.value,
|
||||
varying_axes_in_types.value,
|
||||
softmax_custom_jvp.value,
|
||||
disable_jit.value,
|
||||
debug_key_reuse.value,
|
||||
@ -1084,6 +1085,14 @@ use_direct_linearize = bool_state(
|
||||
help=('Use direct linearization instead JVP followed by partial eval'),
|
||||
include_in_jit_key=True)
|
||||
|
||||
varying_axes_in_types = bool_state(
|
||||
name='jax_varying_axes_in_types',
|
||||
default=False,
|
||||
help=('Adds varying manual axes to ShapedArray to track which mesh axes the'
|
||||
' array is varying over. This will help to remove the efficient'
|
||||
' transpose rewrite machinery in shard_map'),
|
||||
include_in_jit_key=True)
|
||||
|
||||
data_dependent_tracing_fallback = bool_state(
|
||||
name='jax_data_dependent_tracing_fallback',
|
||||
default=False,
|
||||
|
@ -1893,14 +1893,17 @@ def get_sharding(sharding, shape):
|
||||
|
||||
|
||||
class ShapedArray(UnshapedArray):
|
||||
__slots__ = ['shape', 'sharding'] # inherits slots from parent
|
||||
__slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent
|
||||
array_abstraction_level = 2
|
||||
|
||||
def __init__(self, shape, dtype, weak_type=False, *, sharding=None):
|
||||
def __init__(self, shape, dtype, weak_type=False, *, sharding=None,
|
||||
varying_manual_axes: frozenset[AxisName] = frozenset()):
|
||||
self.shape = canonicalize_shape(shape)
|
||||
self.dtype = _dtype_object(dtype)
|
||||
self.weak_type = weak_type
|
||||
self.sharding = get_sharding(sharding, self.shape)
|
||||
if config.varying_axes_in_types.value:
|
||||
self.varying_manual_axes = varying_manual_axes
|
||||
|
||||
def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
|
||||
if shape is None:
|
||||
@ -1911,6 +1914,9 @@ class ShapedArray(UnshapedArray):
|
||||
weak_type = self.weak_type
|
||||
if 'sharding' not in kwargs:
|
||||
kwargs['sharding'] = self.sharding
|
||||
if 'varying_manual_axes' not in kwargs:
|
||||
kwargs['varying_manual_axes'] = getattr(self, 'varying_manual_axes',
|
||||
frozenset())
|
||||
return ShapedArray(shape, dtype, weak_type, **kwargs)
|
||||
|
||||
ndim = property(lambda self: len(self.shape))
|
||||
@ -1927,17 +1933,22 @@ class ShapedArray(UnshapedArray):
|
||||
return (type(self) is type(other)
|
||||
and self.dtype == other.dtype and self.shape == other.shape
|
||||
and self.weak_type == other.weak_type
|
||||
and self.sharding == other.sharding)
|
||||
and self.sharding == other.sharding
|
||||
and (getattr(self, 'varying_manual_axes', frozenset()) ==
|
||||
getattr(other, 'varying_manual_axes', frozenset())))
|
||||
|
||||
def __hash__(self):
|
||||
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
|
||||
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
|
||||
# the unique character code via hash(self.dtype.char)
|
||||
return hash((self.shape, self.dtype, self.weak_type, self.sharding))
|
||||
return hash((self.shape, self.dtype, self.weak_type, self.sharding,
|
||||
getattr(self, 'varying_manual_axes', frozenset())))
|
||||
|
||||
def to_tangent_aval(self):
|
||||
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
|
||||
self.weak_type, sharding=self.sharding)
|
||||
return ShapedArray(
|
||||
self.shape, primal_dtype_to_tangent_dtype(self.dtype),
|
||||
self.weak_type, sharding=self.sharding,
|
||||
varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset()))
|
||||
|
||||
def str_short(self, short_dtypes=False, mesh_axis_types=False):
|
||||
dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
|
||||
|
@ -343,7 +343,7 @@ class BlockSpec:
|
||||
if self.block_shape is None:
|
||||
block_shape = array_aval.shape
|
||||
else:
|
||||
block_shape = self.block_shape
|
||||
block_shape = self.block_shape # type: ignore
|
||||
if len(array_aval.shape) != len(block_shape):
|
||||
raise ValueError(
|
||||
f"Block shape for {origin} (= {block_shape}) "
|
||||
|
Loading…
x
Reference in New Issue
Block a user