All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).
PiperOrigin-RevId: 591933493
This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype
found more latent type errors, which require the understanding of ragedness
and dynamic shapes internals to fix properly.
This type stub is intended to match what pytype currently infers for jax.numpy, which is not particularly accurate in many cases. Future changes will add more accurate types to this stub.
Fix a number of new type errors this reveals to mypy.
PiperOrigin-RevId: 559179804
These functions have custom derivatives, so there seems to be no point to using the double-where guard on the primal function: the implementation can never be differentiated!
PiperOrigin-RevId: 551843160
Their behavior is the same as functions in scipy.special. The only small
difference is in rel_entr function, which unlike scipy.special does not
take the optional parameter 'out'.
Resolves#16630
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
`jnp.finfo(...)` of an Array type yields:
```
TypeError: unhashable type: 'ArrayImpl'
```
However, `np.finfo(...)` no longer accepts NumPy arrays as input either, so it would be consistent to require the user to pass a dtype where they are currently passing an array.
PiperOrigin-RevId: 539174254
Fixes the docstring `jax.scipy.special.gamma`, which was wrapping `scipy.special.gammaln` by mistake. Also adds a note that the function currently only accepts real inputs.
- Add gamma fcn api in scipy.special
- Add tests for this purpose
- Add function to the docs
Currently, there is no implementation of the gamma function in jax
but there is one in scipy.special. This breaks some higher level
jit-compilation like in the blackjax backend for pymc. This commit
adds the missing gamma function.
Resolves: #15409
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.