mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix a couple ad_util.Zero type checks
This commit is contained in:
parent
fb8b3a4df9
commit
866c17c32e
@ -2871,7 +2871,7 @@ def _concatenate_translation_rule(c, *operands, **kwargs):
|
||||
def _concatenate_transpose_rule(t, *operands, dimension):
|
||||
operand_shapes = [o.aval.shape if ad.is_undefined_primal(o) else o.shape
|
||||
for o in operands]
|
||||
if t is ad_util.Zero:
|
||||
if type(t) is ad_util.Zero:
|
||||
return ad_util.Zero
|
||||
else:
|
||||
limit_points = onp.cumsum([shape[dimension] for shape in operand_shapes])
|
||||
@ -2916,7 +2916,7 @@ def _pad_shape_rule(operand, padding_value, *, padding_config):
|
||||
return tuple(out_shape)
|
||||
|
||||
def _pad_transpose(t, operand, padding_value, *, padding_config):
|
||||
if t is ad_util.Zero:
|
||||
if type(t) is ad_util.Zero:
|
||||
return ad_util.Zero
|
||||
|
||||
lo, hi, interior = zip(*padding_config)
|
||||
|
Loading…
x
Reference in New Issue
Block a user