mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #8499 from jakevdp:load-wrapper
PiperOrigin-RevId: 410303891
This commit is contained in:
commit
9491414cf2
@ -449,7 +449,6 @@ array_repr = np.array_repr
|
||||
|
||||
save = np.save
|
||||
savez = np.savez
|
||||
load = np.load
|
||||
|
||||
|
||||
### utility functions
|
||||
@ -633,6 +632,18 @@ def _convert_and_clip_integer(val, dtype):
|
||||
def _constant_like(x, const):
|
||||
return np.array(const, dtype=_dtype(x))
|
||||
|
||||
@_wraps(np.load, update_doc=False)
|
||||
def load(*args, **kwargs):
|
||||
# The main purpose of this wrapper is to recover bfloat16 data types.
|
||||
# Note: this will only work for files created via np.save(), not np.savez().
|
||||
out = np.load(*args, **kwargs)
|
||||
if isinstance(out, np.ndarray):
|
||||
# numpy does not recognize bfloat16, so arrays are serialized as void16
|
||||
if out.dtype == 'V2':
|
||||
out = out.view(bfloat16)
|
||||
out = asarray(out)
|
||||
return out
|
||||
|
||||
### implementations of numpy functions in terms of lax
|
||||
|
||||
@_wraps(np.fmin)
|
||||
|
@ -17,6 +17,7 @@ import collections
|
||||
import functools
|
||||
from functools import partial
|
||||
import inspect
|
||||
import io
|
||||
import itertools
|
||||
import operator
|
||||
from typing import cast, Iterator, Optional, List, Tuple
|
||||
@ -520,6 +521,18 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
func()
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}".format(dtype), "dtype": dtype}
|
||||
for dtype in float_dtypes))
|
||||
def testLoad(self, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
arr = rng((10), dtype)
|
||||
with io.BytesIO() as f:
|
||||
jnp.save(f, arr)
|
||||
f.seek(0)
|
||||
arr_out = jnp.load(f)
|
||||
self.assertArraysEqual(arr, arr_out)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
|
||||
|
Loading…
x
Reference in New Issue
Block a user