Currently, JAX is generating random 8 bit ints for bools, which usually doesn't cause any issues, but in some special cases does. One example is the HLO snapshot dumping code, which surprisingly creates unparseable protos for such inputs.
PiperOrigin-RevId: 513032802
- jax.Array.broadcast: use lax.broadcast instead
- jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead
- jax.Array.split: use jnp.split instead
These are removed because they are not part of the np.ndarray API.
These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
See current behavior difference wrt np.nan_to_num
```
>>> np.nan_to_num(np.array(1, dtype=np.int32))
1
>>> jnp.nan_to_num(jnp.array(1, dtype=jnp.int32))
ValueError: data type <class 'numpy.int32'> not inexact
```
PiperOrigin-RevId: 505735212
Previously, division was only supported in certain situation, and this
led to errors, e.g., when using strides. Now we generalize the polynomials
to also include "floordiv(E, E)" and "mod(E, E)" as atoms, in addition
to dimension variables. A symbolic dimension is now a sum of products
of atoms. (We also changed the documentation to use symbolic dimension
instead of dimension polynomials).
Previously binary operations involving symbolic dimensions would
work only when the other operand is convertible to a symbolic dimension,
e.g., an integer. This resulted in errors when trying "x.shape[0] * 3.5"
and the recourse was to ask the user to add an explicit
"jnp.array(x.shape[0])".
Now we allow binary operations with any operand and the
"jnp.array" is added automatically if the other operand is not
an integer or a symbolic dimension. This means that instead
of an error they may be an error downstream if one tries to use
the result as a dimension. There is one known case where
JAX works with static shapes and with the previous behavior,
but will fail now. When you operate on `np.ndarray` and
symbolic dimension, previously this was kept as a `np.ndarray`
but not it is turned into a JAX array. The following
program will now fail if `x.shape[0]` is a symbolic dimension.:
`jnp.ones(np.arange(5) * x.shape[0])`
Instead you should write
`jnp.ones([i * x.shape[0] for i in range(5)])`
There was a partial fix before, in #13470, but it was incomplete
and the x64 mode was not handled properly.
There are no tests added here; this was discovered by running the
tests with --jax2tf_default_experimental_native_lowering, which
will become default soon.