Merge pull request #10659 from jakevdp:devicearray-pickle

PiperOrigin-RevId: 449995717
This commit is contained in:
jax authors 2022-05-20 08:59:07 -07:00
commit 87d2474cdf
7 changed files with 116 additions and 13 deletions

View File

@ -15,6 +15,16 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
that allows selection between an LU-decomposition based implementation and
an implementation based on QR decomposition.
* {func}`jax.numpy.linalg.qr` now supports `mode="raw"`.
* `pickle`, `copy.copy`, and `copy.deepcopy` now have more complete support when
used on jax arrays ({jax-issue}`#10659`). In particular:
- `pickle` and `deepcopy` previously returned `np.ndarray` objects when used
on a `DeviceArray`; now `DeviceArray` objects are returned. For `deepcopy`,
the copied array is on the same device as the original. For `pickle` the
deserialized array will be on the default device.
- Within function transformations (i.e. traced code), `deepcopy` and `copy`
previously were no-ops. Now they use the same mechanism as `DeviceArray.copy()`.
- Calling `pickle` on a traced array now results in an explicit
`ConcretizationTypeError`.
## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).

View File

@ -474,6 +474,30 @@ instantiate :class:`DeviceArray` objects manually, but rather will create them v
:mod:`jax.numpy` functions like :func:`~jax.numpy.array`, :func:`~jax.numpy.arange`,
:func:`~jax.numpy.linspace`, and others listed above.
Copying and Serialization
~~~~~~~~~~~~~~~~~~~~~~~~~
:class:`~jax.numpy.DeviceArray` objects are designed to work seamlessly with Python
standard library tools where appropriate.
With the built-in :mod:`copy` module, when :func:`copy.copy` or :func:`copy.deepcopy`
encounter a :class:`~jax.numpy.DeviceArray`, it is equivalent to calling the
:meth:`~jaxlib.xla_extension.DeviceArray.copy` method, which will create a copy of
the buffer on the same device as the original array. This will work correctly within
traced/JIT-compiled code, though copy operations may be elided by the compiler
in this context.
When the built-in :mod:`pickle` module encounters a :class:`~jax.numpy.DeviceArray`,
it will be serialized via a compact bit representation in a similar manner to pickled
:class:`numpy.ndarray` objects. When unpickled, the result will be a new
:class:`~jax.numpy.DeviceArray` object *on the default device.*
This is because in general, pickling and unpickling may take place in different runtime
environments, and there is no general way to map the device IDs of one runtime
to the device IDs of another. If :mod:`pickle` is used in traced/JIT-compiled code,
it will result in a :class:`~jax.errors.ConcretizationTypeError`.
Class Reference
~~~~~~~~~~~~~~~
.. autoclass:: jax.numpy.DeviceArray
.. autoclass:: jaxlib.xla_extension.DeviceArrayBase

View File

@ -21,6 +21,7 @@ import weakref
import numpy as np
import jax
from jax import core
from jax._src.config import config
from jax._src import abstract_arrays
@ -265,6 +266,14 @@ for device_array in [DeviceArray]:
setattr(device_array, "__array__", __array__)
def __reduce__(self):
  """Pickle hook: describe this array as a reconstructor call.

  The ``(callable, args, state)`` triple produced by ``self._value.__reduce__()``
  is bundled with the abstract-value fields that the value alone does not
  carry (``weak_type`` and ``named_shape``).

  Returns:
    A ``(reconstruct_device_array, args)`` pair, where ``args`` is
    ``(fun, args, arr_state, aval_state)``.
  """
  # presumably `_value` is the host-side NumPy array -- verify against the class.
  value_reduction = self._value.__reduce__()
  fun, args, arr_state = value_reduction
  aval_state = {
      'weak_type': self.aval.weak_type,
      'named_shape': self.aval.named_shape,
  }
  reconstructor_args = (fun, args, arr_state, aval_state)
  return (reconstruct_device_array, reconstructor_args)
setattr(device_array, "__reduce__", __reduce__)
setattr(device_array, "__str__", partialmethod(_forward_to_value, str))
setattr(device_array, "__bool__", partialmethod(_forward_to_value, bool))
setattr(device_array, "__nonzero__", partialmethod(_forward_to_value, bool))
@ -280,10 +289,6 @@ for device_array in [DeviceArray]:
del to_bytes
setattr(device_array, "tolist", lambda self: self._value.tolist())
# pickle saves and loads just like an ndarray
setattr(device_array, "__reduce__",
partialmethod(_forward_to_value, operator.methodcaller("__reduce__")))
# explicitly set to be unhashable.
setattr(device_array, "__hash__", None)
@ -298,7 +303,16 @@ for device_array in [DeviceArray]:
# pylint: enable=protected-access
class DeletedBuffer: pass
def reconstruct_device_array(fun, args, arr_state, aval_state):
  """Rebuild a device array from the pieces emitted by ``__reduce__``.

  Args:
    fun: callable that creates the (still empty) host array.
    args: positional arguments for ``fun``.
    arr_state: state handed to the host array's ``__setstate__``.
    aval_state: keyword arguments applied to the new array's aval through
      ``aval.update(**aval_state)``.

  Returns:
    The result of ``jax.device_put`` on the restored host array, with its
    ``aval`` replaced by the updated one.
  """
  # Step 1: recreate the host-side array exactly as its own pickling would.
  host_array = fun(*args)
  host_array.__setstate__(arr_state)
  # Step 2: move it to a device; no device is named in the device_put call.
  device_value = jax.device_put(host_array)
  # Step 3: restore the abstract-value metadata stored in aval_state.
  restored_aval = device_value.aval.update(**aval_state)
  device_value.aval = restored_aval
  return device_value
class DeletedBuffer(object): pass
deleted_buffer = DeletedBuffer()

View File

@ -4572,9 +4572,18 @@ def _operator_round(number, ndigits=None):
# If `ndigits` is None, for a builtin float round(7.5) returns an integer.
return out.astype(int) if ndigits is None else out
def _copy(self):
  """Implementation of the ``copy`` operator: delegate to ``self.copy()``."""
  duplicate = self.copy()
  return duplicate
def _deepcopy(self, memo):
  """Implementation of the ``deepcopy`` operator: delegate to ``self.copy()``.

  Args:
    memo: the memo dictionary passed in by the deepcopy protocol; it is
      discarded without being read.
  """
  del memo  # unused
  duplicate = self.copy()
  return duplicate
_operators = {
"getitem": _rewriting_take,
"setitem": _unimplemented_setitem,
"copy": _copy,
"deepcopy": _deepcopy,
"neg": negative,
"pos": positive,
"eq": _defer_to_unrecognized_arg(equal),

View File

@ -603,6 +603,14 @@ class Tracer:
def __oct__(self): return self.aval._oct(self)
def __float__(self): return self.aval._float(self)
def __complex__(self): return self.aval._complex(self)
def __copy__(self):
  """Support ``copy.copy`` by forwarding to the abstract value's ``_copy``."""
  handler = self.aval._copy
  return handler(self)
def __deepcopy__(self, memo):
  """Support ``copy.deepcopy`` by forwarding to the abstract value's ``_deepcopy``."""
  handler = self.aval._deepcopy
  return handler(self, memo)
# raises a useful error on attempts to pickle a Tracer.
def __reduce__(self):
  """Refuse to pickle: always raises ``ConcretizationTypeError``.

  Raises:
    ConcretizationTypeError: unconditionally, with a message pointing at
      ``__reduce__`` as the likely sign of a serialization attempt.
  """
  explanation = ("The error occurred in the __reduce__ method, which may "
                 "indicate an attempt to serialize/pickle a traced value.")
  raise ConcretizationTypeError(self, explanation)
# raises the better error message from ShapedArray
def __setitem__(self, idx, val): return self.aval._setitem(self, idx, val)
@ -650,12 +658,6 @@ class Tracer:
except AttributeError:
return ()
def __copy__(self):
return self
def __deepcopy__(self, unused_memo):
return self
def _origin_msg(self) -> str:
return ""

View File

@ -14,6 +14,7 @@
import collections
import copy
import functools
from functools import partial
import inspect
@ -3828,10 +3829,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
{"testcase_name": f"_dtype={np.dtype(dtype)}_func={func}",
"dtype": dtype, "func": func}
for dtype in all_dtypes
for func in ["array", "copy"]))
for func in ["array", "copy", "copy.copy", "copy.deepcopy"]))
def testArrayCopy(self, dtype, func):
x = jnp.ones(10, dtype=dtype)
copy_func = getattr(jnp, func)
if func == "copy.deepcopy":
copy_func = copy.deepcopy
elif func == "copy.copy":
copy_func = copy.copy
else:
copy_func = getattr(jnp, func)
x_view = jnp.asarray(x)
x_view_jit = jax.jit(jnp.asarray)(x)

View File

@ -24,6 +24,7 @@ except ImportError:
cloudpickle = None
import jax
from jax import core
from jax import numpy as jnp
from jax.config import config
from jax._src import test_util as jtu
@ -73,5 +74,42 @@ class CloudpickleTest(jtu.JaxTestCase):
self.assertEqual(expected, actual)
class PickleTest(jtu.JaxTestCase):
# Round-trip tests for pickling jax arrays, plus the error path for traced values.
# NOTE(review): indentation is not visible in this view, so the extent of the
# `with` blocks in testPickleX64 cannot be read off here -- confirm in the real file.
def testPickleOfDeviceArray(self):
# dumps/loads round trip: values, Python type and aval must all match the original.
x = jnp.arange(10.0)
s = pickle.dumps(x)
y = pickle.loads(s)
self.assertArraysEqual(x, y)
self.assertIsInstance(y, type(x))
self.assertEqual(x.aval, y.aval)
def testPickleOfDeviceArrayWeakType(self):
# Same round trip, starting from an array whose aval is asserted to be weakly typed;
# the final aval comparison checks that the restored aval equals the original.
x = jnp.array(4.0)
self.assertEqual(x.aval.weak_type, True)
s = pickle.dumps(x)
y = pickle.loads(s)
self.assertArraysEqual(x, y)
self.assertIsInstance(y, type(x))
self.assertEqual(x.aval, y.aval)
def testPickleX64(self):
# Creates a float64 array under enable_x64 and unpickles under disable_x64;
# the restored array is expected to report float32 for both dtype and aval.dtype.
with jax.experimental.enable_x64():
x = jnp.array(4.0, dtype='float64')
s = pickle.dumps(x)
with jax.experimental.disable_x64():
y = pickle.loads(s)
self.assertEqual(x.dtype, jnp.float64)
# Values are compared without dtype checking because x and y differ in dtype.
self.assertArraysEqual(x, y, check_dtypes=False)
self.assertEqual(y.dtype, jnp.float32)
self.assertEqual(y.aval.dtype, jnp.float32)
self.assertIsInstance(y, type(x))
def testPickleTracerError(self):
# Running pickle.dumps under jax.jit must raise ConcretizationTypeError.
with self.assertRaises(core.ConcretizationTypeError):
jax.jit(pickle.dumps)(0)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())