mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #10584 from mattjj:remove-jnp-array-handling-raw-buffers
PiperOrigin-RevId: 448720084
This commit is contained in:
commit
f26133cc1e
@ -1806,15 +1806,16 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
|
||||
lax_internal._check_user_dtype_supported(dtype, "array")
|
||||
|
||||
# Here we make a judgment call: we only return a weakly-typed array when the
|
||||
# input object itself is weakly typed. That ensures asarray(x) is a no-op whenever
|
||||
# x is weak, but avoids introducing weak types with something like array([1, 2, 3])
|
||||
# input object itself is weakly typed. That ensures asarray(x) is a no-op
|
||||
# whenever x is weak, but avoids introducing weak types with something like
|
||||
# array([1, 2, 3])
|
||||
weak_type = dtype is None and dtypes.is_weakly_typed(object)
|
||||
|
||||
# For Python scalar literals, call coerce_to_array to catch any overflow errors.
|
||||
# We don't use dtypes.is_python_scalar because we don't want this triggering for
|
||||
# traced values. We do this here because it matters whether or not dtype is None.
|
||||
# We don't assign the result because we want the raw object to be used for type
|
||||
# inference below.
|
||||
# For Python scalar literals, call coerce_to_array to catch any overflow
|
||||
# errors. We don't use dtypes.is_python_scalar because we don't want this
|
||||
# triggering for traced values. We do this here because it matters whether or
|
||||
# not dtype is None. We don't assign the result because we want the raw object
|
||||
# to be used for type inference below.
|
||||
if isinstance(object, (bool, int, float, complex)):
|
||||
_ = dtypes.coerce_to_array(object, dtype)
|
||||
|
||||
@ -1838,17 +1839,13 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
|
||||
ndarray_types = (device_array.DeviceArray, core.Tracer)
|
||||
|
||||
if not _any(isinstance(leaf, ndarray_types) for leaf in leaves):
|
||||
# TODO(jakevdp): falling back to numpy here fails to overflow for lists containing
|
||||
# large integers; see discussion in https://github.com/google/jax/pull/6047.
|
||||
# More correct would be to call coerce_to_array on each leaf, but this may have
|
||||
# performance implications.
|
||||
# TODO(jakevdp): falling back to numpy here fails to overflow for lists
|
||||
# containing large integers; see discussion in
|
||||
# https://github.com/google/jax/pull/6047. More correct would be to call
|
||||
# coerce_to_array on each leaf, but this may have performance implications.
|
||||
out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)
|
||||
elif isinstance(object, ndarray_types):
|
||||
if object.aval is None:
|
||||
# object is a raw buffer; convert to device array on its current device.
|
||||
aval = ShapedArray(object.xla_shape().dimensions(), object.dtype,
|
||||
weak_type=bool(getattr(object, "weak_type", False)))
|
||||
object = device_array.make_device_array(aval, object.device(), object)
|
||||
assert object.aval is not None
|
||||
out = _array_copy(object) if copy else object
|
||||
elif isinstance(object, (list, tuple)):
|
||||
if object:
|
||||
|
@ -117,7 +117,12 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
|
||||
def testDeviceBufferToArray(self):
|
||||
sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))
|
||||
buf = sda.device_buffers[-1]
|
||||
|
||||
# Changed in https://github.com/google/jax/pull/10584 not to access
|
||||
# sda.device_buffers, which isn't supported, and instead ensure fast slices
|
||||
# of the arrays returned by pmap are set up correctly.
|
||||
# buf = sda.device_buffers[-1]
|
||||
buf = sda[-1]
|
||||
|
||||
view = jnp.array(buf, copy=False)
|
||||
self.assertArraysEqual(sda[-1], view)
|
||||
|
Loading…
x
Reference in New Issue
Block a user