Merge pull request #8499 from jakevdp:load-wrapper

PiperOrigin-RevId: 410303891
This commit is contained in:
jax authors 2021-11-16 11:25:44 -08:00
commit 9491414cf2
2 changed files with 25 additions and 1 deletions

View File

@ -449,7 +449,6 @@ array_repr = np.array_repr
save = np.save
savez = np.savez
load = np.load
### utility functions
@ -633,6 +632,18 @@ def _convert_and_clip_integer(val, dtype):
def _constant_like(x, const):
return np.array(const, dtype=_dtype(x))
@_wraps(np.load, update_doc=False)
def load(*args, **kwargs):
# The main purpose of this wrapper is to recover bfloat16 data types.
# Note: this will only work for files created via np.save(), not np.savez().
out = np.load(*args, **kwargs)
if isinstance(out, np.ndarray):
# numpy does not recognize bfloat16, so arrays are serialized as void16
if out.dtype == 'V2':
out = out.view(bfloat16)
out = asarray(out)
return out
### implementations of numpy functions in terms of lax
@_wraps(np.fmin)

View File

@ -17,6 +17,7 @@ import collections
import functools
from functools import partial
import inspect
import io
import itertools
import operator
from typing import cast, Iterator, Optional, List, Tuple
@ -520,6 +521,18 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
with self.assertRaises(NotImplementedError):
func()
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(dtype), "dtype": dtype}
for dtype in float_dtypes))
def testLoad(self, dtype):
rng = jtu.rand_default(self.rng())
arr = rng((10), dtype)
with io.BytesIO() as f:
jnp.save(f, arr)
f.seek(0)
arr_out = jnp.load(f)
self.assertArraysEqual(arr, arr_out)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,