fix scalar broadcast bug in safe_mul translation

This commit is contained in:
Matthew Johnson 2019-03-23 17:10:01 -07:00
parent 87143d8e0a
commit b041435f29

View File

@ -1732,8 +1732,8 @@ ad.defbilinear_broadcasting(_brcast, mul_p, mul, mul) # TODO
def _safe_mul_translation_rule(c, x, y):
dtype = c.GetShape(x).numpy_dtype()
zero = c.Constant(onp.array(0, dtype=dtype))
out_shape = tuple(onp.maximum(c.GetShape(x).dimensions(),
c.GetShape(y).dimensions()))
out_shape = broadcast_shapes(c.GetShape(x).dimensions(),
c.GetShape(y).dimensions())
return c.Select(c.Or(c.Eq(x, zero), c.Eq(y, zero)),
c.Broadcast(zero, out_shape),
c.Mul(x, y))