mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix symbolic zero handling in _pad_transpose
tested manually against example from @matthewdhoffman
This commit is contained in:
parent
6434d9d029
commit
0600b738f4
@ -2386,8 +2386,11 @@ def _pad_shape_rule(operand, padding_value, padding_config):
|
||||
return tuple(out_shape)
|
||||
|
||||
def _pad_transpose(t, operand, padding_value, padding_config):
|
||||
lo, hi, interior = zip(*padding_config)
|
||||
if t is ad_util.zero:
|
||||
return [ad_util.zero if operand is None else None,
|
||||
ad_util.zero if padding_value is None else None]
|
||||
|
||||
lo, hi, interior = zip(*padding_config)
|
||||
total = lambda x: _reduce_sum(x, list(range(t.ndim)))
|
||||
|
||||
def t_op():
|
||||
|
Loading…
x
Reference in New Issue
Block a user