mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[sharding_in_types] Add clamp_p
sharding rule.
PiperOrigin-RevId: 720428881
This commit is contained in:
parent
ae705fef9c
commit
7ed7e0b5b1
@ -2149,7 +2149,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
|
||||
if dtypes.issubdtype(dtype, dtypes.extended):
|
||||
return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr]
|
||||
|
||||
if (config.sharding_in_types.value and sharding is None and
|
||||
if (config.sharding_in_types.value and sharding is None and shape is None and
|
||||
isinstance(x, Array)):
|
||||
sharding = x.aval.sharding
|
||||
else:
|
||||
@ -4577,6 +4577,9 @@ def _clamp_shape_rule(min, operand, max):
|
||||
f"(), got max.shape={max.shape}, {operand.shape=}.")
|
||||
return operand.shape
|
||||
|
||||
def _clamp_sharding_rule(min, operand, max):
|
||||
return operand.sharding
|
||||
|
||||
_clamp_dtype_rule = partial(naryop_dtype_rule, _input_dtype, [_any, _any, _any],
|
||||
'clamp')
|
||||
|
||||
@ -4617,7 +4620,8 @@ def _clamp_batch_rule(batched_args, batch_dims, **params):
|
||||
x = broadcast(x, min.shape)
|
||||
return clamp_p.bind(min, x, max), 0
|
||||
|
||||
clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp')
|
||||
clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp',
|
||||
sharding_rule=_clamp_sharding_rule)
|
||||
ad.defjvp(clamp_p,
|
||||
lambda g, min, operand, max:
|
||||
select(bitwise_and(gt(min, operand), lt(min, max)),
|
||||
|
Loading…
x
Reference in New Issue
Block a user