George Necula 8e931a82ff [shape_poly] Generalize binary operations with symbolic dimensions.
Previously, binary operations involving symbolic dimensions
worked only when the other operand was convertible to a symbolic dimension,
e.g., an integer. This resulted in errors for expressions such as
"x.shape[0] * 3.5", and the recourse was to ask the user to add an explicit
"jnp.array(x.shape[0])".
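
As a rough illustration of that explicit conversion, the sketch below uses a
static shape, where `x.shape[0]` is a plain Python int. The array `x` and its
shape are made up for the example; with a symbolic dimension, the same
`jnp.array(...)` wrapper is what users previously had to add by hand.

```python
import jax.numpy as jnp

# Hypothetical input; with static shapes x.shape[0] is the Python int 4.
x = jnp.zeros((4, 2))

# Explicitly convert the dimension to a JAX array before combining it
# with a non-integer operand.
scaled = jnp.array(x.shape[0]) * 3.5
```

The result is a scalar JAX array rather than a dimension, so it is meant to be
used as a value in a computation, not as part of a shape.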

Now we allow binary operations with any operand, and the
"jnp.array" is added automatically if the other operand is not
an integer or a symbolic dimension. This means that instead
of an error at the operation, there may be an error downstream if
one tries to use the result as a dimension. There is one known case where
JAX works with static shapes and with the previous behavior,
but will fail now. When you operate on an `np.ndarray` and a
symbolic dimension, the result was previously kept as an `np.ndarray`,
but now it is turned into a JAX array. The following
program will now fail if `x.shape[0]` is a symbolic dimension:

`jnp.ones(np.arange(5) * x.shape[0])`

Instead, you should write:

`jnp.ones([i * x.shape[0] for i in range(5)])`
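
A minimal sketch of the recommended form, run here with a static shape only
(the input `x` and its shape are assumptions made for the example). Each
dimension is computed as a separate Python-level product, so the shape is
passed as a list of individual dimensions rather than as a single array.

```python
import jax.numpy as jnp

# Hypothetical input with a static leading dimension of 3.
x = jnp.zeros((3, 4))

# Build the shape one dimension at a time instead of multiplying
# an np.ndarray by x.shape[0].
dims = [i * x.shape[0] for i in range(5)]
y = jnp.ones(dims)
```

Note that `range(5)` starts at 0, so the first dimension is 0 and the
resulting array is empty; the point of the example is the shape construction.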
2023-01-21 04:26:59 -08:00
..
2022-12-05 15:42:26 -08:00