Better documentation for jnp.load

This commit is contained in:
Jake VanderPlas 2024-10-19 06:20:20 -07:00
parent 884f1dc3a1
commit 0a85ba5f82
2 changed files with 41 additions and 7 deletions

View File

@ -32,10 +32,11 @@ from functools import partial
import importlib
import math
import operator
import os
import string
import types
from typing import ( Any, Literal, NamedTuple,
Protocol, TypeVar, Union,overload)
from typing import (Any, IO, Literal, NamedTuple,
Protocol, TypeVar, Union, overload)
import warnings
import jax
@ -320,11 +321,43 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array:
return clip(val, min_val, max_val).astype(dtype)
@util.implements(np.load, update_doc=False)
def load(*args: Any, **kwargs: Any) -> Array:
def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array:
"""Load JAX arrays from npy files.
JAX wrapper of :func:`numpy.load`.
This function is a simple wrapper of :func:`numpy.load`, but in the case of
``.npy`` files created with :func:`numpy.save` or :func:`jax.numpy.save`,
the output will be returned as a :class:`jax.Array`, and ``bfloat16`` data
types will be restored. For ``.npz`` files, results will be returned as
normal NumPy arrays.
This function requires concrete array inputs, and is not compatible with
transformations like :func:`jax.jit` or :func:`jax.vmap`.
Args:
file: string, bytes, or path-like object containing the array data.
args, kwargs: for additional arguments, see :func:`numpy.load`
Returns:
the array stored in the file.
See also:
- :func:`jax.numpy.save`: save an array to a file.
Examples:
>>> import io
>>> f = io.BytesIO() # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)
"""
# The main purpose of this wrapper is to recover bfloat16 data types.
# Note: this will only work for files created via np.save(), not np.savez().
out = np.load(*args, **kwargs)
out = np.load(file, *args, **kwargs)
if isinstance(out, np.ndarray):
# numpy does not recognize bfloat16, so arrays are serialized as void16
if out.dtype == 'V2':

View File

@ -3,7 +3,8 @@ from __future__ import annotations
import builtins
from collections.abc import Callable, Sequence
from typing import Any, Literal, NamedTuple, Protocol, TypeVar, Union, overload
import os
from typing import Any, IO, Literal, NamedTuple, Protocol, TypeVar, Union, overload
from jax._src import core as _core
from jax._src import dtypes as _dtypes
@ -577,7 +578,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> Union[Array, tuple[Array, Array]]: ...
def load(*args: Any, **kwargs: Any) -> Array: ...
def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array: ...
def log(x: ArrayLike, /) -> Array: ...
def log10(x: ArrayLike, /) -> Array: ...
def log1p(x: ArrayLike, /) -> Array: ...