mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix scalar broadcast bug in safe_mul translation
This commit is contained in:
parent
87143d8e0a
commit
b041435f29
@ -1732,8 +1732,8 @@ ad.defbilinear_broadcasting(_brcast, mul_p, mul, mul) # TODO
|
||||
def _safe_mul_translation_rule(c, x, y):
|
||||
dtype = c.GetShape(x).numpy_dtype()
|
||||
zero = c.Constant(onp.array(0, dtype=dtype))
|
||||
out_shape = tuple(onp.maximum(c.GetShape(x).dimensions(),
|
||||
c.GetShape(y).dimensions()))
|
||||
out_shape = broadcast_shapes(c.GetShape(x).dimensions(),
|
||||
c.GetShape(y).dimensions())
|
||||
return c.Select(c.Or(c.Eq(x, zero), c.Eq(y, zero)),
|
||||
c.Broadcast(zero, out_shape),
|
||||
c.Mul(x, y))
|
||||
|
Loading…
x
Reference in New Issue
Block a user