mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 02:56:05 +00:00
Merge pull request #10659 from jakevdp:devicearray-pickle
PiperOrigin-RevId: 449995717
This commit is contained in:
commit
87d2474cdf
10
CHANGELOG.md
10
CHANGELOG.md
@ -15,6 +15,16 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
that allows selection between an LU-decomposition based implementation and
|
||||
an implementation based on QR decomposition.
|
||||
* {func}`jax.numpy.linalg.qr` now supports `mode="raw"`.
|
||||
* `pickle`, `copy.copy`, and `copy.deepcopy` now have more complete support when
|
||||
used on jax arrays ({jax-issue}`#10659`). In particular:
|
||||
- `pickle` and `deepcopy` previously returned `np.ndarray` objects when used
|
||||
on a `DeviceArray`; now `DeviceArray` objects are returned. For `deepcopy`,
|
||||
the copied array is on the same device as the original. For `pickle` the
|
||||
deserialized array will be on the default device.
|
||||
- Within function transformations (i.e. traced code), `deepcopy` and `copy`
|
||||
previously were no-ops. Now they use the same mechanism as `DeviceArray.copy()`.
|
||||
- Calling `pickle` on a traced array now results in an explicit
|
||||
`ConcretizationTypeError`.
|
||||
|
||||
## jaxlib 0.3.11 (Unreleased)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
|
||||
|
@ -474,6 +474,30 @@ instantiate :class:`DeviceArray` objects manually, but rather will create them v
|
||||
:mod:`jax.numpy` functions like :func:`~jax.numpy.array`, :func:`~jax.numpy.arange`,
|
||||
:func:`~jax.numpy.linspace`, and others listed above.
|
||||
|
||||
Copying and Serialization
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
:class:`~jax.numpy.DeviceArray`` objects are designed to work seamlessly with Python
|
||||
standard library tools where appropriate.
|
||||
|
||||
With the built-in :mod:`copy` module, when :func:`copy.copy` or :func:`copy.deepcopy`
|
||||
encounder a :class:`~jax.numpy.DeviceArray`, it is equivalent to calling the
|
||||
:meth:`~jaxlib.xla_extension.DeviceArray.copy` method, which will create a copy of
|
||||
the buffer on the same device as the original array. This will work correctly within
|
||||
traced/JIT-compiled code, though copy operations may be elided by the compiler
|
||||
in this context.
|
||||
|
||||
When the built-in :mod:`pickle` module encounters a :class:`~jax.numpy.DeviceArray`,
|
||||
it will be serialized via a compact bit representation in a similar manner to pickled
|
||||
:class:`numpy.ndarray` objects. When unpickled, the result will be a new
|
||||
:class:`~jax.numpy.DeviceArray` object *on the default device.*
|
||||
This is because in general, pickling and unpickling may take place in different runtime
|
||||
environments, and there is no general way to map the device IDs of one runtime
|
||||
to the device IDs of another. If :mod:`pickle` is used in traced/JIT-compiled code,
|
||||
it will result in a :class:`~jax.errors.ConcretizationTypeError`.
|
||||
|
||||
Class Reference
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: jax.numpy.DeviceArray
|
||||
|
||||
.. autoclass:: jaxlib.xla_extension.DeviceArrayBase
|
||||
|
@ -21,6 +21,7 @@ import weakref
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src.config import config
|
||||
from jax._src import abstract_arrays
|
||||
@ -265,6 +266,14 @@ for device_array in [DeviceArray]:
|
||||
|
||||
setattr(device_array, "__array__", __array__)
|
||||
|
||||
def __reduce__(self):
|
||||
fun, args, arr_state = self._value.__reduce__()
|
||||
aval_state = {'weak_type': self.aval.weak_type,
|
||||
'named_shape': self.aval.named_shape}
|
||||
return (reconstruct_device_array, (fun, args, arr_state, aval_state))
|
||||
|
||||
setattr(device_array, "__reduce__", __reduce__)
|
||||
|
||||
setattr(device_array, "__str__", partialmethod(_forward_to_value, str))
|
||||
setattr(device_array, "__bool__", partialmethod(_forward_to_value, bool))
|
||||
setattr(device_array, "__nonzero__", partialmethod(_forward_to_value, bool))
|
||||
@ -280,10 +289,6 @@ for device_array in [DeviceArray]:
|
||||
del to_bytes
|
||||
setattr(device_array, "tolist", lambda self: self._value.tolist())
|
||||
|
||||
# pickle saves and loads just like an ndarray
|
||||
setattr(device_array, "__reduce__",
|
||||
partialmethod(_forward_to_value, operator.methodcaller("__reduce__")))
|
||||
|
||||
# explicitly set to be unhashable.
|
||||
setattr(device_array, "__hash__", None)
|
||||
|
||||
@ -298,7 +303,16 @@ for device_array in [DeviceArray]:
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
class DeletedBuffer: pass
|
||||
def reconstruct_device_array(fun, args, arr_state, aval_state):
|
||||
"""Method to reconstruct a device array from a serialized state."""
|
||||
np_value = fun(*args)
|
||||
np_value.__setstate__(arr_state)
|
||||
jnp_value = jax.device_put(np_value)
|
||||
jnp_value.aval = jnp_value.aval.update(**aval_state)
|
||||
return jnp_value
|
||||
|
||||
|
||||
class DeletedBuffer(object): pass
|
||||
deleted_buffer = DeletedBuffer()
|
||||
|
||||
|
||||
|
@ -4572,9 +4572,18 @@ def _operator_round(number, ndigits=None):
|
||||
# If `ndigits` is None, for a builtin float round(7.5) returns an integer.
|
||||
return out.astype(int) if ndigits is None else out
|
||||
|
||||
def _copy(self):
|
||||
return self.copy()
|
||||
|
||||
def _deepcopy(self, memo):
|
||||
del memo # unused
|
||||
return self.copy()
|
||||
|
||||
_operators = {
|
||||
"getitem": _rewriting_take,
|
||||
"setitem": _unimplemented_setitem,
|
||||
"copy": _copy,
|
||||
"deepcopy": _deepcopy,
|
||||
"neg": negative,
|
||||
"pos": positive,
|
||||
"eq": _defer_to_unrecognized_arg(equal),
|
||||
|
14
jax/core.py
14
jax/core.py
@ -603,6 +603,14 @@ class Tracer:
|
||||
def __oct__(self): return self.aval._oct(self)
|
||||
def __float__(self): return self.aval._float(self)
|
||||
def __complex__(self): return self.aval._complex(self)
|
||||
def __copy__(self): return self.aval._copy(self)
|
||||
def __deepcopy__(self, memo): return self.aval._deepcopy(self, memo)
|
||||
|
||||
# raises a useful error on attempts to pickle a Tracer.
|
||||
def __reduce__(self):
|
||||
raise ConcretizationTypeError(
|
||||
self, ("The error occurred in the __reduce__ method, which may "
|
||||
"indicate an attempt to serialize/pickle a traced value."))
|
||||
|
||||
# raises the better error message from ShapedArray
|
||||
def __setitem__(self, idx, val): return self.aval._setitem(self, idx, val)
|
||||
@ -650,12 +658,6 @@ class Tracer:
|
||||
except AttributeError:
|
||||
return ()
|
||||
|
||||
def __copy__(self):
|
||||
return self
|
||||
|
||||
def __deepcopy__(self, unused_memo):
|
||||
return self
|
||||
|
||||
def _origin_msg(self) -> str:
|
||||
return ""
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import functools
|
||||
from functools import partial
|
||||
import inspect
|
||||
@ -3828,10 +3829,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
{"testcase_name": f"_dtype={np.dtype(dtype)}_func={func}",
|
||||
"dtype": dtype, "func": func}
|
||||
for dtype in all_dtypes
|
||||
for func in ["array", "copy"]))
|
||||
for func in ["array", "copy", "copy.copy", "copy.deepcopy"]))
|
||||
def testArrayCopy(self, dtype, func):
|
||||
x = jnp.ones(10, dtype=dtype)
|
||||
copy_func = getattr(jnp, func)
|
||||
if func == "copy.deepcopy":
|
||||
copy_func = copy.deepcopy
|
||||
elif func == "copy.copy":
|
||||
copy_func = copy.copy
|
||||
else:
|
||||
copy_func = getattr(jnp, func)
|
||||
|
||||
x_view = jnp.asarray(x)
|
||||
x_view_jit = jax.jit(jnp.asarray)(x)
|
||||
|
@ -24,6 +24,7 @@ except ImportError:
|
||||
cloudpickle = None
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import numpy as jnp
|
||||
from jax.config import config
|
||||
from jax._src import test_util as jtu
|
||||
@ -73,5 +74,42 @@ class CloudpickleTest(jtu.JaxTestCase):
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
class PickleTest(jtu.JaxTestCase):
|
||||
|
||||
def testPickleOfDeviceArray(self):
|
||||
x = jnp.arange(10.0)
|
||||
s = pickle.dumps(x)
|
||||
y = pickle.loads(s)
|
||||
self.assertArraysEqual(x, y)
|
||||
self.assertIsInstance(y, type(x))
|
||||
self.assertEqual(x.aval, y.aval)
|
||||
|
||||
def testPickleOfDeviceArrayWeakType(self):
|
||||
x = jnp.array(4.0)
|
||||
self.assertEqual(x.aval.weak_type, True)
|
||||
s = pickle.dumps(x)
|
||||
y = pickle.loads(s)
|
||||
self.assertArraysEqual(x, y)
|
||||
self.assertIsInstance(y, type(x))
|
||||
self.assertEqual(x.aval, y.aval)
|
||||
|
||||
def testPickleX64(self):
|
||||
with jax.experimental.enable_x64():
|
||||
x = jnp.array(4.0, dtype='float64')
|
||||
s = pickle.dumps(x)
|
||||
|
||||
with jax.experimental.disable_x64():
|
||||
y = pickle.loads(s)
|
||||
|
||||
self.assertEqual(x.dtype, jnp.float64)
|
||||
self.assertArraysEqual(x, y, check_dtypes=False)
|
||||
self.assertEqual(y.dtype, jnp.float32)
|
||||
self.assertEqual(y.aval.dtype, jnp.float32)
|
||||
self.assertIsInstance(y, type(x))
|
||||
|
||||
def testPickleTracerError(self):
|
||||
with self.assertRaises(core.ConcretizationTypeError):
|
||||
jax.jit(pickle.dumps)(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user