Previously, division was only supported in certain situation, and this
led to errors, e.g., when using strides. Now we generalize the polynomials
to also include "floordiv(E, E)" and "mod(E, E)" as atoms, in addition
to dimension variables. A symbolic dimension is now a sum of products
of atoms. (We also changed the documentation to use symbolic dimension
instead of dimension polynomials).
Previously binary operations involving symbolic dimensions would
work only when the other operand is convertible to a symbolic dimension,
e.g., an integer. This resulted in errors when trying "x.shape[0] * 3.5"
and the recourse was to ask the user to add an explicit
"jnp.array(x.shape[0])".
Now we allow binary operations with any operand and the
"jnp.array" is added automatically if the other operand is not
an integer or a symbolic dimension. This means that instead
of an error they may be an error downstream if one tries to use
the result as a dimension. There is one known case where
JAX works with static shapes and with the previous behavior,
but will fail now. When you operate on `np.ndarray` and
symbolic dimension, previously this was kept as a `np.ndarray`
but not it is turned into a JAX array. The following
program will now fail if `x.shape[0]` is a symbolic dimension.:
`jnp.ones(np.arange(5) * x.shape[0])`
Instead you should write
`jnp.ones([i * x.shape[0] for i in range(5)])`
There was a partial fix before, in #13470, but it was incomplete
and the x64 mode was not handled properly.
There are no tests added here; this was discovered by running the
tests with --jax2tf_default_experimental_native_lowering, which
will become default soon.
its gradient is taken and one of the inputs is NaN. This CL adds
a short description of the behavior to the jax.numpy.where docs,
which is the logical place that users would look for it.
PiperOrigin-RevId: 488654036
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.
It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.
Fix many test cases that were shown to be broken with a larger number of test cases enabled.
PiperOrigin-RevId: 487406670
Including blackman, bartlett, hamming, hanning, kaiser.
Why? Previously these were implemented by computing the output on host at trace-time and embedding the result as a large constant array. Computing the results via lax operations is more in the spirit of jax.numpy.
Including blackman, bartlett, hamming, hanning, kaiser.
Why? Previously these were implemented by embedding large constants; this should be more performant.
The idea with jnp.canonicalize_shape is that it handles non-tuple shapes, i.e.
intended to be scalar-like arguments like Python builtin ints or numpy scalar
types or 0D arrays. To do that, it checks numpy.ndim(shape) == 0. But
numpy.ndim might attempt to convert its argument to a numpy.ndarray, which
breaks when the argument is a tuple with Tracers inside!
Instead, let's just check if the argument is one of the canonical sequence
types (list or tuple) and if so then not even call numpy.ndim.
Prior to this the user had to explicitly call core.dimension_as_value whenever
using a potentially polymorphic shape in the computation, e.g., x +
core.dimension_as_value(x.shape[0]). Furthermore, jnp.array(x.shape[0])
would fail.
Now, these operations are allowed implicitly,
and the user can call `jnp.array(x.shape[0])`.
This uses an internal extensibility mechanism called __jax_array__
that is experimental and probably not fully implemented.
If e.g. `x : f32[10, n]` then we want to handle Python expressions like `x[0]`.
To do that, we can use a generalized version of `dynamic_slice` which allows
dynamic slice sizes (where the result shape depends on those slice sizes).
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
Added TODOs to take fast path for indices wherever it is possible to do that. If a correct index is passed during getitem and if that index exists on `Array`, then the fast path is taken (see the test in this CL).
PiperOrigin-RevId: 473342504