--
bf15ba5310d5f9009571928f70548bcbc7e856c3 by Matthew Johnson <mattjj@google.com>:
don't device transfer in convert_element_type
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
PiperOrigin-RevId: 363995032
fixed some bugs in the bicgstab method and adjusted tolerance for scipy comparison
fixed flake8
added some tests for gradients, fixed symmetry checks, modified lax.cond -> jnp.where
comment out gmres grad check, to be addressed on future PR
increasing tolerance for bicgstab grad test
change to order 1 checks for bicgstab (gmres still fails in order 1) for internal CI check
remove grad checks for now
changing tolerance to pass numpy comparison test
Similarly to `jnp.einsum`, whenever we encounter an extension to the
positional NumPy API (in the case of reductions, the extension is
whenever a non-integer axis is specified), we reroute the call to a
parallel primitive instead of the standard lax reductions.
Note that this makes the parallel primitives implement a strict subset
of functionality of the lax reductions so in the future (when we decide
that we want axes to be truly first class) we can always swap out the
implementation for the parallel version. But, it makes sense to keep
them separate for the ease of prototyping in the near future.
We aren't supporting eigenvectors for now because eigenvectors are not
uniquely determined by the input matrix, they're only determined up to
'gauge' (that is multiplication by a complex scalar with absolute value
1). Note, this means that second derivatives aren't supported, because
they involve differentiating the eigvals jvp, which itself depends on
eigenvectors.
Because we now have a facade around the lax library, we can expose the lax_linalg primitives directly in lax without creating circular dependency problems.
Leave a few forwarding stubs to be removed later.
PiperOrigin-RevId: 340658800