Remove jax.partial from the JAX API.

Use functools.partial instead.
This commit is contained in:
Peter Hawkins 2021-09-16 15:46:26 -04:00
parent ab464bd70c
commit f35ab3693d
2 changed files with 3 additions and 3 deletions

View File

@ -15,8 +15,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* `jax.api` has been removed. Functions that were available as `jax.api.*`
were aliases for functions in `jax.*`; please use the functions in
`jax.*` instead.
* `jax.lax.partial` was an accidental export that has now been removed. Use
`functools.partial` instead.
* `jax.partial` and `jax.lax.partial` were accidental exports that have now
been removed. Use `functools.partial` from the Python standard library
instead.
* Boolean scalar indices now raise a `TypeError`; previously this silently
returned wrong results ({jax-issue}`#7925`).
* Many more `jax.numpy` functions now require array-like inputs, and will error

View File

@ -112,7 +112,6 @@ from ._src.api import (
xla, # TODO(phawkins): update users to avoid this.
xla_computation as xla_computation,
)
from functools import partial as partial # TODO(phawkins): remove this export.
from .experimental.maps import soft_pmap as soft_pmap
from .version import __version__ as __version__