mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 01:46:06 +00:00

Previously, binary operations involving symbolic dimensions worked only when the other operand was convertible to a symbolic dimension, e.g., an integer. This resulted in errors for expressions such as `x.shape[0] * 3.5`, and the recourse was to ask the user to add an explicit `jnp.array(x.shape[0])`.

Now we allow binary operations with any operand, and the `jnp.array` conversion is added automatically if the other operand is not an integer or a symbolic dimension. This means that, instead of an error at the binary operation, there may be an error downstream if one tries to use the result as a dimension.

There is one known case that works with static shapes, and worked with symbolic dimensions under the previous behavior, but will fail now. When you operate on an `np.ndarray` and a symbolic dimension, the result was previously kept as an `np.ndarray`, but now it is turned into a JAX array. The following program will now fail if `x.shape[0]` is a symbolic dimension: `jnp.ones(np.arange(5) * x.shape[0])`. Instead, you should write `jnp.ones([i * x.shape[0] for i in range(5)])`, where each element `i * x.shape[0]` is an integer times a symbolic dimension and therefore stays a symbolic dimension.
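A minimal sketch of both behaviors, assuming a JAX release that exposes `jax.export.symbolic_shape` (that module postdates this change); the function names `scale`, `ones_ok`, and `ones_bad` are hypothetical. It traces with `jax.eval_shape` over a symbolic input shape, so no concrete data is needed.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# Assumption: a JAX release that provides `jax.export.symbolic_shape`.
b, = export.symbolic_shape("b")
x_spec = jax.ShapeDtypeStruct((b,), jnp.float32)

def scale(x):
    # 3.5 is neither an integer nor a symbolic dimension, so `x.shape[0]`
    # is converted with `jnp.array` automatically; the product is an
    # ordinary JAX scalar, not a dimension.
    return x * (x.shape[0] * 3.5)

def ones_ok(x):
    # int * symbolic dimension stays a symbolic dimension, so every entry
    # of the list is a valid dimension: (0, b, 2*b, 3*b, 4*b).
    return jnp.ones([i * x.shape[0] for i in range(5)])

def ones_bad(x):
    # Per the description above, `np.arange(5) * x.shape[0]` becomes a JAX
    # array, which cannot be used as a shape.
    return jnp.ones(np.arange(5) * x.shape[0])

scaled = jax.eval_shape(scale, x_spec)
print(scaled.shape, scaled.dtype)

ones = jax.eval_shape(ones_ok, x_spec)
print(ones.shape)

try:
    jax.eval_shape(ones_bad, x_spec)
except Exception as e:  # the exact exception type may vary across JAX versions
    print("ones_bad failed:", type(e).__name__)
```

The list comprehension is the safer pattern because it never leaves the domain of symbolic dimensions, whereas mixing in an `np.ndarray` hands the arithmetic over to array semantics.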