mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[sharding_in_types] Make sharding
arg to ShapedArray kwarg only
PiperOrigin-RevId: 726272943
This commit is contained in:
parent
15cd83ae00
commit
3ec7a67e51
@ -1839,7 +1839,7 @@ class ShapedArray(UnshapedArray):
|
||||
__slots__ = ['shape', 'sharding'] # inherits slots from parent
|
||||
array_abstraction_level = 2
|
||||
|
||||
def __init__(self, shape, dtype, weak_type=False, sharding=None):
|
||||
def __init__(self, shape, dtype, weak_type=False, *, sharding=None):
|
||||
self.shape = canonicalize_shape(shape)
|
||||
self.dtype = _dtype_object(dtype)
|
||||
self.weak_type = weak_type
|
||||
@ -1883,7 +1883,7 @@ class ShapedArray(UnshapedArray):
|
||||
def to_tangent_aval(self):
|
||||
if config.sharding_in_types.value:
|
||||
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
|
||||
self.weak_type, self.sharding)
|
||||
self.weak_type, sharding=self.sharding)
|
||||
else:
|
||||
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
|
||||
self.weak_type)
|
||||
|
Loading…
x
Reference in New Issue
Block a user