Merge pull request #10584 from mattjj:remove-jnp-array-handling-raw-buffers

PiperOrigin-RevId: 448720084
This commit is contained in:
jax authors 2022-05-14 14:37:46 -07:00
commit f26133cc1e
2 changed files with 19 additions and 17 deletions

View File

@@ -1806,15 +1806,16 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
lax_internal._check_user_dtype_supported(dtype, "array")
# Here we make a judgment call: we only return a weakly-typed array when the
# input object itself is weakly typed. That ensures asarray(x) is a no-op whenever
# x is weak, but avoids introducing weak types with something like array([1, 2, 3])
# input object itself is weakly typed. That ensures asarray(x) is a no-op
# whenever x is weak, but avoids introducing weak types with something like
# array([1, 2, 3])
weak_type = dtype is None and dtypes.is_weakly_typed(object)
# For Python scalar literals, call coerce_to_array to catch any overflow errors.
# We don't use dtypes.is_python_scalar because we don't want this triggering for
# traced values. We do this here because it matters whether or not dtype is None.
# We don't assign the result because we want the raw object to be used for type
# inference below.
# For Python scalar literals, call coerce_to_array to catch any overflow
# errors. We don't use dtypes.is_python_scalar because we don't want this
# triggering for traced values. We do this here because it matters whether or
# not dtype is None. We don't assign the result because we want the raw object
# to be used for type inference below.
if isinstance(object, (bool, int, float, complex)):
_ = dtypes.coerce_to_array(object, dtype)
@@ -1838,17 +1839,13 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
ndarray_types = (device_array.DeviceArray, core.Tracer)
if not _any(isinstance(leaf, ndarray_types) for leaf in leaves):
# TODO(jakevdp): falling back to numpy here fails to overflow for lists containing
# large integers; see discussion in https://github.com/google/jax/pull/6047.
# More correct would be to call coerce_to_array on each leaf, but this may have
# performance implications.
# TODO(jakevdp): falling back to numpy here fails to overflow for lists
# containing large integers; see discussion in
# https://github.com/google/jax/pull/6047. More correct would be to call
# coerce_to_array on each leaf, but this may have performance implications.
out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)
elif isinstance(object, ndarray_types):
if object.aval is None:
# object is a raw buffer; convert to device array on its current device.
aval = ShapedArray(object.xla_shape().dimensions(), object.dtype,
weak_type=bool(getattr(object, "weak_type", False)))
object = device_array.make_device_array(aval, object.device(), object)
assert object.aval is not None
out = _array_copy(object) if copy else object
elif isinstance(object, (list, tuple)):
if object:

View File

@@ -117,7 +117,12 @@ class PythonPmapTest(jtu.JaxTestCase):
def testDeviceBufferToArray(self):
sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))
buf = sda.device_buffers[-1]
# Changed in https://github.com/google/jax/pull/10584 not to access
# sda.device_buffers, which isn't supported, and instead ensure fast slices
# of the arrays returned by pmap are set up correctly.
# buf = sda.device_buffers[-1]
buf = sda[-1]
view = jnp.array(buf, copy=False)
self.assertArraysEqual(sda[-1], view)