fix a couple ad_util.Zero type checks

This commit is contained in:
Matthew Johnson 2020-06-08 13:22:13 -07:00
parent fb8b3a4df9
commit 866c17c32e

View File

@ -2871,7 +2871,7 @@ def _concatenate_translation_rule(c, *operands, **kwargs):
def _concatenate_transpose_rule(t, *operands, dimension):
operand_shapes = [o.aval.shape if ad.is_undefined_primal(o) else o.shape
for o in operands]
if t is ad_util.Zero:
if type(t) is ad_util.Zero:
return ad_util.Zero
else:
limit_points = onp.cumsum([shape[dimension] for shape in operand_shapes])
@ -2916,7 +2916,7 @@ def _pad_shape_rule(operand, padding_value, *, padding_config):
return tuple(out_shape)
def _pad_transpose(t, operand, padding_value, *, padding_config):
if t is ad_util.Zero:
if type(t) is ad_util.Zero:
return ad_util.Zero
lo, hi, interior = zip(*padding_config)