[sharding_in_types] Make sharding arg to ShapedArray kwarg only

PiperOrigin-RevId: 726272943
This commit is contained in:
Yash Katariya 2025-02-12 18:22:15 -08:00 committed by jax authors
parent 15cd83ae00
commit 3ec7a67e51

View File

@ -1839,7 +1839,7 @@ class ShapedArray(UnshapedArray):
__slots__ = ['shape', 'sharding'] # inherits slots from parent
array_abstraction_level = 2
def __init__(self, shape, dtype, weak_type=False, sharding=None):
def __init__(self, shape, dtype, weak_type=False, *, sharding=None):
self.shape = canonicalize_shape(shape)
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
@ -1883,7 +1883,7 @@ class ShapedArray(UnshapedArray):
def to_tangent_aval(self):
if config.sharding_in_types.value:
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type, self.sharding)
self.weak_type, sharding=self.sharding)
else:
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)