1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

[sharding_in_types] Add clamp_p sharding rule.

PiperOrigin-RevId: 720428881
This commit is contained in:
Yash Katariya 2025-01-27 21:57:36 -08:00 committed by jax authors
parent ae705fef9c
commit 7ed7e0b5b1

@ -2149,7 +2149,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
if dtypes.issubdtype(dtype, dtypes.extended):
return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr]
if (config.sharding_in_types.value and sharding is None and
if (config.sharding_in_types.value and sharding is None and shape is None and
isinstance(x, Array)):
sharding = x.aval.sharding
else:
@ -4577,6 +4577,9 @@ def _clamp_shape_rule(min, operand, max):
f"(), got max.shape={max.shape}, {operand.shape=}.")
return operand.shape
def _clamp_sharding_rule(min, operand, max):
return operand.sharding
_clamp_dtype_rule = partial(naryop_dtype_rule, _input_dtype, [_any, _any, _any],
'clamp')
@ -4617,7 +4620,8 @@ def _clamp_batch_rule(batched_args, batch_dims, **params):
x = broadcast(x, min.shape)
return clamp_p.bind(min, x, max), 0
clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp')
clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp',
sharding_rule=_clamp_sharding_rule)
ad.defjvp(clamp_p,
lambda g, min, operand, max:
select(bitwise_and(gt(min, operand), lt(min, max)),