mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Better documentation for jnp.load
This commit is contained in:
parent
884f1dc3a1
commit
0a85ba5f82
@ -32,10 +32,11 @@ from functools import partial
|
||||
import importlib
|
||||
import math
|
||||
import operator
|
||||
import os
|
||||
import string
|
||||
import types
|
||||
from typing import ( Any, Literal, NamedTuple,
|
||||
Protocol, TypeVar, Union,overload)
|
||||
from typing import (Any, IO, Literal, NamedTuple,
|
||||
Protocol, TypeVar, Union, overload)
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
@ -320,11 +321,43 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array:
|
||||
return clip(val, min_val, max_val).astype(dtype)
|
||||
|
||||
|
||||
@util.implements(np.load, update_doc=False)
|
||||
def load(*args: Any, **kwargs: Any) -> Array:
|
||||
def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array:
|
||||
"""Load JAX arrays from npy files.
|
||||
|
||||
JAX wrapper of :func:`numpy.load`.
|
||||
|
||||
This function is a simple wrapper of :func:`numpy.load`, but in the case of
|
||||
``.npy`` files created with :func:`numpy.save` or :func:`jax.numpy.save`,
|
||||
the output will be returned as a :class:`jax.Array`, and ``bfloat16`` data
|
||||
types will be restored. For ``.npz`` files, results will be returned as
|
||||
normal NumPy arrays.
|
||||
|
||||
This function requires concrete array inputs, and is not compatible with
|
||||
transformations like :func:`jax.jit` or :func:`jax.vmap`.
|
||||
|
||||
Args:
|
||||
file: string, bytes, or path-like object containing the array data.
|
||||
args, kwargs: for additional arguments, see :func:`numpy.load`
|
||||
|
||||
Returns:
|
||||
the array stored in the file.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.save`: save an array to a file.
|
||||
|
||||
Examples:
|
||||
>>> import io
|
||||
>>> f = io.BytesIO() # use an in-memory file-like object.
|
||||
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
|
||||
>>> jnp.save(f, x)
|
||||
>>> f.seek(0)
|
||||
0
|
||||
>>> jnp.load(f)
|
||||
Array([2, 4, 6, 8], dtype=bfloat16)
|
||||
"""
|
||||
# The main purpose of this wrapper is to recover bfloat16 data types.
|
||||
# Note: this will only work for files created via np.save(), not np.savez().
|
||||
out = np.load(*args, **kwargs)
|
||||
out = np.load(file, *args, **kwargs)
|
||||
if isinstance(out, np.ndarray):
|
||||
# numpy does not recognize bfloat16, so arrays are serialized as void16
|
||||
if out.dtype == 'V2':
|
||||
|
@ -3,7 +3,8 @@ from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Literal, NamedTuple, Protocol, TypeVar, Union, overload
|
||||
import os
|
||||
from typing import Any, IO, Literal, NamedTuple, Protocol, TypeVar, Union, overload
|
||||
|
||||
from jax._src import core as _core
|
||||
from jax._src import dtypes as _dtypes
|
||||
@ -577,7 +578,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
|
||||
axis: int = 0,
|
||||
*, device: _Device | _Sharding | None = ...) -> Union[Array, tuple[Array, Array]]: ...
|
||||
|
||||
def load(*args: Any, **kwargs: Any) -> Array: ...
|
||||
def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array: ...
|
||||
def log(x: ArrayLike, /) -> Array: ...
|
||||
def log10(x: ArrayLike, /) -> Array: ...
|
||||
def log1p(x: ArrayLike, /) -> Array: ...
|
||||
|
Loading…
x
Reference in New Issue
Block a user