diff --git a/jax/_src/core.py b/jax/_src/core.py index ed1838ced..90020dccc 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1839,7 +1839,7 @@ class ShapedArray(UnshapedArray): __slots__ = ['shape', 'sharding'] # inherits slots from parent array_abstraction_level = 2 - def __init__(self, shape, dtype, weak_type=False, sharding=None): + def __init__(self, shape, dtype, weak_type=False, *, sharding=None): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type @@ -1883,7 +1883,7 @@ class ShapedArray(UnshapedArray): def to_tangent_aval(self): if config.sharding_in_types.value: return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type, self.sharding) + self.weak_type, sharding=self.sharding) else: return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type)