We should not silently convert non-array inputs to arrays, because this can lead to silent performance degredation. This brings the sparse_plus API in line with other APIs in this module.
PiperOrigin-RevId: 617190413
Currently distribution parameters such as stddev and scale are expected to be
weakly typed scalars. When they're passed as float32 they can cause an upcast
of the initialized arrays even when the dtype is specified as e.g. bfloat16.
Some users were surprised by this.
PiperOrigin-RevId: 611858446
Without this decorator, we get a warning from typeguard:
```
.../typeguard/_checkers.py:474: UserWarning: Typeguard cannot check the Initializer protocol because it is a non-runtime protocol. If you would like to type check this protocol, please use @typing.runtime_checkable
```
PiperOrigin-RevId: 598588778
This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
In the common case (real values) these are all single-expression jaxprs themselves, so putting them out of line just makes things more verbose.
There's no reason to include stuff like this in a jaxpr:
```
cxd:bool[8,16] = pjit[
jaxpr={ lambda ; cxe:f32[8,16]. let
cxf:bool[8,16] = is_finite cxe
in (cxf,) }
name=isfinite
] cxc
```
PiperOrigin-RevId: 587047955
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
PiperOrigin-RevId: 572587137
I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype
found more latent type errors, which require the understanding of ragedness
and dynamic shapes internals to fix properly.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:
* remove symbolic_equal_dim and symbolic_equal_shape in favor of the
newer definitely_equal and definitely_equal_shape
* remove is_special_dim_size, which checks that a value is a
dimension expression (not a constant). Some uses are replaced
with `not is_constant_dim` and others with `is_dim`.
* introduce concrete_dim_or_error to check that a value is
a dimension
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.
PiperOrigin-RevId: 472705623