Replace references to DeviceArray with Array.

A number of stale references are lurking in our documentation.
This commit is contained in:
Peter Hawkins 2023-08-18 16:50:36 -04:00
parent 97af33c4d1
commit 2c32660a8f
40 changed files with 161 additions and 173 deletions

View File

@ -251,9 +251,9 @@
"colab_type": "text"
},
"source": [
"A `ShardedDeviceArray` is effectively an `ndarray` subclass, but it's stored in pieces spread across the memory of multiple devices. Results from `pmap` functions are left sharded in device memory so that they can be operated on by subsequent `pmap` functions without moving data around, at least in some cases. But these results logically appear just like a single array.\n",
"A sharded `Array` is effectively an `ndarray` subclass, but it's stored in pieces spread across the memory of multiple devices. Results from `pmap` functions are left sharded in device memory so that they can be operated on by subsequent `pmap` functions without moving data around, at least in some cases. But these results logically appear just like a single array.\n",
"\n",
"When you call a non-`pmap` function on a `ShardedDeviceArray`, like a standard `jax.numpy` function, communication happens behind the scenes to bring the values to one device (or back to the host in the case of the matplotlib function above):"
"When you call a non-`pmap` function on an `Array`, like a standard `jax.numpy` function, communication happens behind the scenes to bring the values to one device (or back to the host in the case of the matplotlib function above):"
]
},
{

View File

@ -2344,8 +2344,8 @@
"One piece missing is device memory persistence for arrays. That is, we've\n",
"defined `handle_result` to transfer results back to CPU memory as NumPy\n",
"arrays, but it's often preferable to avoid transferring results just to\n",
"transfer them back for the next operation. We can do that by introducing a\n",
"`DeviceArray` class, which can wrap XLA buffers and otherwise duck-type\n",
"transfer them back for the next operation. We can do that by introducing an\n",
"`Array` class, which can wrap XLA buffers and otherwise duck-type\n",
"`numpy.ndarray`s:"
]
},
@ -2356,9 +2356,9 @@
"outputs": [],
"source": [
"def handle_result(aval: ShapedArray, buf): # noqa: F811\n",
" return DeviceArray(aval, buf)\n",
" return Array(aval, buf)\n",
"\n",
"class DeviceArray:\n",
"class Array:\n",
" buf: Any\n",
" aval: ShapedArray\n",
"\n",
@ -2381,9 +2381,9 @@
" _rmul = staticmethod(mul)\n",
" _gt = staticmethod(greater)\n",
" _lt = staticmethod(less)\n",
"input_handlers[DeviceArray] = lambda x: x.buf\n",
"input_handlers[Array] = lambda x: x.buf\n",
"\n",
"jax_types.add(DeviceArray)"
"jax_types.add(Array)"
]
},
{

View File

@ -1822,15 +1822,15 @@ print(ys)
One piece missing is device memory persistence for arrays. That is, we've
defined `handle_result` to transfer results back to CPU memory as NumPy
arrays, but it's often preferable to avoid transferring results just to
transfer them back for the next operation. We can do that by introducing a
`DeviceArray` class, which can wrap XLA buffers and otherwise duck-type
transfer them back for the next operation. We can do that by introducing an
`Array` class, which can wrap XLA buffers and otherwise duck-type
`numpy.ndarray`s:
```{code-cell}
def handle_result(aval: ShapedArray, buf): # noqa: F811
return DeviceArray(aval, buf)
return Array(aval, buf)
class DeviceArray:
class Array:
buf: Any
aval: ShapedArray
@ -1853,9 +1853,9 @@ class DeviceArray:
_rmul = staticmethod(mul)
_gt = staticmethod(greater)
_lt = staticmethod(less)
input_handlers[DeviceArray] = lambda x: x.buf
input_handlers[Array] = lambda x: x.buf
jax_types.add(DeviceArray)
jax_types.add(Array)
```
```{code-cell}

View File

@ -1813,15 +1813,15 @@ print(ys)
# One piece missing is device memory persistence for arrays. That is, we've
# defined `handle_result` to transfer results back to CPU memory as NumPy
# arrays, but it's often preferable to avoid transferring results just to
# transfer them back for the next operation. We can do that by introducing a
# `DeviceArray` class, which can wrap XLA buffers and otherwise duck-type
# transfer them back for the next operation. We can do that by introducing an
# `Array` class, which can wrap XLA buffers and otherwise duck-type
# `numpy.ndarray`s:
# +
def handle_result(aval: ShapedArray, buf): # noqa: F811
return DeviceArray(aval, buf)
return Array(aval, buf)
class DeviceArray:
class Array:
buf: Any
aval: ShapedArray
@ -1844,9 +1844,9 @@ class DeviceArray:
_rmul = staticmethod(mul)
_gt = staticmethod(greater)
_lt = staticmethod(less)
input_handlers[DeviceArray] = lambda x: x.buf
input_handlers[Array] = lambda x: x.buf
jax_types.add(DeviceArray)
jax_types.add(Array)
# +

View File

@ -68,7 +68,7 @@
"source": [
"So far, everything is just like NumPy. A big appeal of JAX is that you don't need to learn a new API. Many common NumPy programs would run just as well in JAX if you substitute `np` for `jnp`. However, there are some important differences which we touch on at the end of this section.\n",
"\n",
"You can notice the first difference if you check the type of `x`. It is a variable of type `DeviceArray`, which is the way JAX represents arrays."
"You can notice the first difference if you check the type of `x`. It is a variable of type `Array`, which is the way JAX represents arrays."
]
},
{
@ -81,7 +81,7 @@
{
"data": {
"text/plain": [
"DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)"
"Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)"
]
},
"execution_count": 2,
@ -277,8 +277,8 @@
{
"data": {
"text/plain": [
"(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
" DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))"
"(Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
" Array([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))"
]
},
"execution_count": 7,
@ -338,8 +338,8 @@
{
"data": {
"text/plain": [
"(DeviceArray(0.03999995, dtype=float32),\n",
" DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))"
"(Array(0.03999995, dtype=float32),\n",
" Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))"
]
},
"execution_count": 8,
@ -395,7 +395,7 @@
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFilteredStackTrace\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-7433a86e7375>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquared_error_with_aux\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mFilteredStackTrace\u001b[0m: TypeError: Gradient only defined for scalar-output functions. Output was (DeviceArray(0.03999995, dtype=float32), DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).\n\nThe stack trace above excludes JAX-internal frames."
"\u001b[0;31mFilteredStackTrace\u001b[0m: TypeError: Gradient only defined for scalar-output functions. Output was (Array(0.03999995, dtype=float32), Array([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).\n\nThe stack trace above excludes JAX-internal frames."
]
}
],
@ -425,8 +425,8 @@
{
"data": {
"text/plain": [
"(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
" DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))"
"(Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
" Array([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))"
]
},
"execution_count": 10,
@ -530,7 +530,7 @@
"\u001b[0;32m<ipython-input-12-709e2d7ddd3f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0min_place_modify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Raises error when we cast input to jnp.ndarray\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-11-fce65eb843c7>\u001b[0m in \u001b[0;36min_place_modify\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0min_place_modify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m123\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py\u001b[0m in \u001b[0;36m_unimplemented_setitem\u001b[0;34m(self, i, x)\u001b[0m\n\u001b[1;32m 6594\u001b[0m \u001b[0;34m\"or another .at[] method: \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6595\u001b[0m \"https://jax.readthedocs.io/en/latest/jax.ops.html\")\n\u001b[0;32m-> 6596\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6597\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6598\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_operator_round\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumber\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mndigits\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html"
"\u001b[0;31mTypeError\u001b[0m: '<class 'jaxlib.xla_extension.Array'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html"
]
}
],
@ -557,7 +557,7 @@
{
"data": {
"text/plain": [
"DeviceArray([123, 2, 3], dtype=int32)"
"Array([123, 2, 3], dtype=int32)"
]
},
"execution_count": 13,
@ -594,7 +594,7 @@
{
"data": {
"text/plain": [
"DeviceArray([1, 2, 3], dtype=int32)"
"Array([1, 2, 3], dtype=int32)"
]
},
"execution_count": 14,

View File

@ -45,7 +45,7 @@ print(x)
So far, everything is just like NumPy. A big appeal of JAX is that you don't need to learn a new API. Many common NumPy programs would run just as well in JAX if you substitute `np` for `jnp`. However, there are some important differences which we touch on at the end of this section.
You can notice the first difference if you check the type of `x`. It is a variable of type `DeviceArray`, which is the way JAX represents arrays.
You can notice the first difference if you check the type of `x`. It is a variable of type `Array`, which is the way JAX represents arrays.
```{code-cell} ipython3
:id: 3fLtgPUAn7mi

View File

@ -401,7 +401,7 @@
{
"data": {
"text/plain": [
"DeviceArray(30, dtype=int32, weak_type=True)"
"Array(30, dtype=int32, weak_type=True)"
]
},
"execution_count": 8,

View File

@ -37,7 +37,7 @@
{
"data": {
"text/plain": [
"DeviceArray([11., 20., 29.], dtype=float32)"
"Array([11., 20., 29.], dtype=float32)"
]
},
"execution_count": 1,
@ -104,7 +104,7 @@
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
@ -149,7 +149,7 @@
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
@ -201,7 +201,7 @@
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
@ -240,7 +240,7 @@
{
"data": {
"text/plain": [
"DeviceArray([[11., 11.],\n",
"Array([[11., 11.],\n",
" [20., 20.],\n",
" [29., 29.]], dtype=float32)"
]
@ -281,7 +281,7 @@
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
@ -320,7 +320,7 @@
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},

View File

@ -175,9 +175,9 @@
{
"data": {
"text/plain": [
"DeviceArray([[2., 0., 0.],\n",
" [0., 2., 0.],\n",
" [0., 0., 2.]], dtype=float32)"
"Array([[2., 0., 0.],\n",
" [0., 2., 0.],\n",
" [0., 0., 2.]], dtype=float32)"
]
},
"execution_count": 6,
@ -312,7 +312,7 @@
{
"data": {
"text/plain": [
"DeviceArray([ 2.4, -2.4, 2.4], dtype=float32)"
"Array([ 2.4, -2.4, 2.4], dtype=float32)"
]
},
"execution_count": 9,
@ -356,7 +356,7 @@
{
"data": {
"text/plain": [
"DeviceArray([-2.4, -4.8, 2.4], dtype=float32)"
"Array([-2.4, -4.8, 2.4], dtype=float32)"
]
},
"execution_count": 10,
@ -459,8 +459,8 @@
{
"data": {
"text/plain": [
"DeviceArray([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
"Array([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
]
},
"execution_count": 12,
@ -503,7 +503,7 @@
{
"data": {
"text/plain": [
"DeviceArray([-2.4, -4.8, 2.4], dtype=float32)"
"Array([-2.4, -4.8, 2.4], dtype=float32)"
]
},
"execution_count": 13,
@ -548,8 +548,8 @@
{
"data": {
"text/plain": [
"DeviceArray([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
"Array([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
]
},
"execution_count": 14,
@ -586,8 +586,8 @@
{
"data": {
"text/plain": [
"DeviceArray([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
"Array([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
]
},
"execution_count": 15,
@ -623,8 +623,8 @@
{
"data": {
"text/plain": [
"DeviceArray([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
"Array([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
]
},
"execution_count": 16,

View File

@ -347,7 +347,7 @@
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n"
"\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n"
]
}
],
@ -587,7 +587,7 @@
{
"data": {
"text/plain": [
"DeviceArray(9, dtype=int32)"
"Array(9, dtype=int32)"
]
},
"execution_count": 14,
@ -754,7 +754,7 @@
{
"data": {
"text/plain": [
"DeviceArray(45, dtype=int32)"
"Array(45, dtype=int32)"
]
},
"execution_count": 17,
@ -848,7 +848,7 @@
{
"data": {
"text/plain": [
"DeviceArray(45, dtype=int32)"
"Array(45, dtype=int32)"
]
},
"execution_count": 19,
@ -1020,7 +1020,7 @@
{
"data": {
"text/plain": [
"DeviceArray([0, 0], dtype=uint32)"
"Array([0, 0], dtype=uint32)"
]
},
"execution_count": 23,
@ -1408,7 +1408,7 @@
{
"data": {
"text/plain": [
"DeviceArray(5., dtype=float32)"
"Array(5., dtype=float32)"
]
},
"execution_count": 33,
@ -1553,7 +1553,7 @@
{
"data": {
"text/plain": [
"DeviceArray(4, dtype=int32, weak_type=True)"
"Array(4, dtype=int32, weak_type=True)"
]
},
"execution_count": 37,
@ -1616,7 +1616,7 @@
{
"data": {
"text/plain": [
"DeviceArray([-1.], dtype=float32)"
"Array([-1.], dtype=float32)"
]
},
"execution_count": 38,
@ -1689,7 +1689,7 @@
{
"data": {
"text/plain": [
"DeviceArray(10, dtype=int32, weak_type=True)"
"Array(10, dtype=int32, weak_type=True)"
]
},
"execution_count": 39,
@ -1733,7 +1733,7 @@
{
"data": {
"text/plain": [
"DeviceArray(45, dtype=int32, weak_type=True)"
"Array(45, dtype=int32, weak_type=True)"
]
},
"execution_count": 40,
@ -2000,7 +2000,7 @@
" 104 if np.any(np.isnan(py_val)):\n",
"--> 105 raise FloatingPointError(\"invalid value\")\n",
" 106 else:\n",
" 107 return DeviceArray(device_buffer, *result_shape)\n",
" 107 return Array(device_buffer, *result_shape)\n",
"\n",
"FloatingPointError: invalid value\n",
"```"
@ -2222,7 +2222,7 @@
" array([254, 255, 0, 1], dtype=uint8)\n",
"\n",
" >>> jnp.arange(254.0, 258.0).astype('uint8')\n",
" DeviceArray([254, 255, 255, 255], dtype=uint8)\n",
" Array([254, 255, 255, 255], dtype=uint8)\n",
" ```\n",
" This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n",
"\n",

View File

@ -987,7 +987,7 @@ FloatingPointError Traceback (most recent call last)
104 if np.any(np.isnan(py_val)):
--> 105 raise FloatingPointError("invalid value")
106 else:
107 return DeviceArray(device_buffer, *result_shape)
107 return Array(device_buffer, *result_shape)
FloatingPointError: invalid value
```
@ -1142,7 +1142,7 @@ Many such cases are discussed in detail in the sections above; here we list seve
array([254, 255, 0, 1], dtype=uint8)
>>> jnp.arange(254.0, 258.0).astype('uint8')
DeviceArray([254, 255, 255, 255], dtype=uint8)
Array([254, 255, 255, 255], dtype=uint8)
```
This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.

View File

@ -234,7 +234,7 @@
{
"data": {
"text/plain": [
"DeviceArray(3.0485873, dtype=float32)"
"Array(3.0485873, dtype=float32)"
]
},
"execution_count": 8,
@ -1578,7 +1578,7 @@
{
"data": {
"text/plain": [
"DeviceArray(-0.14112, dtype=float32)"
"Array(-0.14112, dtype=float32)"
]
},
"execution_count": 50,
@ -1901,7 +1901,7 @@
"output_type": "stream",
"text": [
"called f_bwd!\n",
"(DeviceArray(-0.9899925, dtype=float32),)\n"
"(Array(-0.9899925, dtype=float32),)\n"
]
}
],
@ -2013,9 +2013,9 @@
"> <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()\n",
"-> return g\n",
"(Pdb) p x\n",
"DeviceArray(9., dtype=float32)\n",
"Array(9., dtype=float32)\n",
"(Pdb) p g\n",
"DeviceArray(-0.91113025, dtype=float32)\n",
"Array(-0.91113025, dtype=float32)\n",
"(Pdb) q\n",
"```"
]
@ -2085,7 +2085,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"{'a': 1.0, 'b': (DeviceArray(0.841471, dtype=float32), DeviceArray(-0.4161468, dtype=float32))}\n"
"{'a': 1.0, 'b': (Array(0.841471, dtype=float32), Array(-0.4161468, dtype=float32))}\n"
]
}
],
@ -2107,7 +2107,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Point(x=DeviceArray(2.5403023, dtype=float32), y=array(0., dtype=float32))\n"
"Point(x=Array(2.5403023, dtype=float32), y=array(0., dtype=float32))\n"
]
}
],
@ -2166,7 +2166,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"{'a': 1.0, 'b': (DeviceArray(0.841471, dtype=float32), DeviceArray(-0.4161468, dtype=float32))}\n"
"{'a': 1.0, 'b': (Array(0.841471, dtype=float32), Array(-0.4161468, dtype=float32))}\n"
]
}
],
@ -2188,7 +2188,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Point(x=DeviceArray(2.5403023, dtype=float32), y=DeviceArray(-0., dtype=float32))\n"
"Point(x=Array(2.5403023, dtype=float32), y=Array(-0., dtype=float32))\n"
]
}
],

View File

@ -1036,9 +1036,9 @@ jax.grad(foo)(3.)
> <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()
-> return g
(Pdb) p x
DeviceArray(9., dtype=float32)
Array(9., dtype=float32)
(Pdb) p g
DeviceArray(-0.91113025, dtype=float32)
Array(-0.91113025, dtype=float32)
(Pdb) q
```

View File

@ -236,7 +236,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"{'W': DeviceArray([-0.16965576, -0.8774645 , -1.4901344 ], dtype=float32), 'b': DeviceArray(-0.29227236, dtype=float32)}\n"
"{'W': Array([-0.16965576, -0.8774645 , -1.4901344 ], dtype=float32), 'b': Array(-0.29227236, dtype=float32)}\n"
]
}
],
@ -1204,7 +1204,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(DeviceArray(3.1415927, dtype=float32),)\n"
"(Array(3.1415927, dtype=float32),)\n"
]
}
],
@ -1470,7 +1470,7 @@
{
"data": {
"text/plain": [
"DeviceArray(6.-8.j, dtype=complex64)"
"Array(6.-8.j, dtype=complex64)"
]
},
"execution_count": 31,
@ -1511,7 +1511,7 @@
{
"data": {
"text/plain": [
"DeviceArray(-27.034945-3.8511531j, dtype=complex64)"
"Array(-27.034945-3.8511531j, dtype=complex64)"
]
},
"execution_count": 32,
@ -1549,7 +1549,7 @@
{
"data": {
"text/plain": [
"DeviceArray(1.-0.j, dtype=complex64)"
"Array(1.-0.j, dtype=complex64)"
]
},
"execution_count": 33,
@ -1602,12 +1602,12 @@
{
"data": {
"text/plain": [
"DeviceArray([[-0.75342447 +0.j , -3.0509021 -10.940544j ,\n",
" 5.989684 +3.5422976j],\n",
" [-3.0509021 +10.940544j , -8.904487 +0.j ,\n",
" -5.1351547 -6.5593696j],\n",
" [ 5.989684 -3.5422976j, -5.1351547 +6.5593696j,\n",
" 0.01320434 +0.j ]], dtype=complex64)"
"Array([[-0.75342447 +0.j , -3.0509021 -10.940544j ,\n",
" 5.989684 +3.5422976j],\n",
" [-3.0509021 +10.940544j , -8.904487 +0.j ,\n",
" -5.1351547 -6.5593696j],\n",
" [ 5.989684 -3.5422976j, -5.1351547 +6.5593696j,\n",
" 0.01320434 +0.j ]], dtype=complex64)"
]
},
"execution_count": 34,

View File

@ -1071,7 +1071,7 @@
{
"data": {
"text/plain": [
"DeviceArray(-0.4003078, dtype=float32, weak_type=True)"
"Array(-0.4003078, dtype=float32, weak_type=True)"
]
},
"execution_count": 8,

View File

@ -170,7 +170,7 @@
{
"data": {
"text/plain": [
"jax.interpreters.xla._DeviceArray"
"jaxlib.xla_extension.ArrayImpl"
]
},
"execution_count": 5,
@ -248,7 +248,7 @@
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-7-6b90817377fe>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# JAX: immutable arrays\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m: '<class 'jax.interpreters.xla._DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?"
"\u001b[0;31mTypeError\u001b[0m: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?"
]
}
],
@ -327,7 +327,7 @@
{
"data": {
"text/plain": [
"DeviceArray(2., dtype=float32)"
"Array(2., dtype=float32)"
]
},
"execution_count": 9,
@ -390,7 +390,7 @@
{
"data": {
"text/plain": [
"DeviceArray(2., dtype=float32)"
"Array(2., dtype=float32)"
]
},
"execution_count": 11,
@ -426,7 +426,7 @@
{
"data": {
"text/plain": [
"DeviceArray([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)"
"Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)"
]
},
"execution_count": 12,
@ -462,7 +462,7 @@
{
"data": {
"text/plain": [
"DeviceArray([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)"
"Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)"
]
},
"execution_count": 13,
@ -638,7 +638,7 @@
{
"data": {
"text/plain": [
"DeviceArray([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)"
"Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)"
]
},
"execution_count": 18,
@ -739,7 +739,7 @@
{
"data": {
"text/plain": [
"DeviceArray([0.25773212, 5.3623195 , 5.4032435 ], dtype=float32)"
"Array([0.25773212, 5.3623195 , 5.4032435 ], dtype=float32)"
]
},
"execution_count": 20,
@ -788,7 +788,7 @@
{
"data": {
"text/plain": [
"DeviceArray([1.4344584, 4.3004413, 7.9897013], dtype=float32)"
"Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)"
]
},
"execution_count": 21,
@ -908,7 +908,7 @@
{
"data": {
"text/plain": [
"DeviceArray(-1, dtype=int32)"
"Array(-1, dtype=int32)"
]
},
"execution_count": 24,
@ -948,7 +948,7 @@
{
"data": {
"text/plain": [
"DeviceArray(1, dtype=int32)"
"Array(1, dtype=int32)"
]
},
"execution_count": 25,
@ -1086,7 +1086,7 @@
{
"data": {
"text/plain": [
"DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)"
"Array([1., 1., 1., 1., 1., 1.], dtype=float32)"
]
},
"execution_count": 28,

View File

@ -145,7 +145,7 @@
{
"data": {
"text/plain": [
"DeviceArray(-213.23558, dtype=float32)"
"Array(-213.23558, dtype=float32)"
]
},
"execution_count": 13,
@ -229,9 +229,9 @@
{
"data": {
"text/plain": [
"DeviceArray([-147.84033203, -207.02204895, -109.26074982, -243.80830383,\n",
" -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n",
" -126.60544586, -190.81988525], dtype=float32)"
"Array([-147.84033203, -207.02204895, -109.26074982, -243.80830383,\n",
" -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n",
" -126.60544586, -190.81988525], dtype=float32)"
]
},
"execution_count": 16,
@ -270,9 +270,9 @@
{
"data": {
"text/plain": [
"DeviceArray([-147.84033203, -207.02204895, -109.26074982, -243.80830383,\n",
" -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n",
" -126.60544586, -190.81988525], dtype=float32)"
"Array([-147.84033203, -207.02204895, -109.26074982, -243.80830383,\n",
" -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n",
" -126.60544586, -190.81988525], dtype=float32)"
]
},
"execution_count": 17,

View File

@ -649,9 +649,8 @@ def make_array_from_single_device_arrays(
) -> ArrayImpl:
r"""Returns a ``jax.Array`` from a sequence of ``jax.Array``\s on a single device.
``jax.Array`` on a single device is analogous to a ``DeviceArray``. You can use
this function if you have already ``jax.device_put`` the value on a single
device and want to create a global Array. The smaller ``jax.Array``\s should be
You can use this function if you have already ``jax.device_put`` the value on
a single device and want to create a global Array. The smaller ``jax.Array``\s should be
addressable and belong to the current process.
Args:
@ -702,8 +701,7 @@ def make_array_from_single_device_arrays(
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True)
# TODO(phawkins): ideally the cast() could be checked. Revisit this after
# removing DeviceArray.
# TODO(phawkins): ideally the cast() could be checked.
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),
committed=True)

View File

@ -1620,7 +1620,7 @@ class ConcreteArray(ShapedArray):
def __eq__(self, other):
if (type(self) is type(other) and self.dtype == other.dtype
and self.shape == other.shape and self.weak_type == other.weak_type):
with eval_context(): # in case self.val is a DeviceArray
with eval_context(): # in case self.val is an Array
return (self.val == other.val).all()
else:
return False

View File

@ -30,21 +30,21 @@ SUPPORTED_DTYPES = frozenset({
def to_dlpack(x: Array, take_ownership: bool = False,
stream: int | None = None):
"""Returns a DLPack tensor that encapsulates a ``DeviceArray`` `x`.
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.
Takes ownership of the contents of ``x``; leaves `x` in an invalid/deleted
Takes ownership of the contents of ``x``; leaves ``x`` in an invalid/deleted
state.
Args:
x: a ``DeviceArray``, on either CPU or GPU.
x: a :class:`~jax.Array`, on either CPU or GPU.
take_ownership: If ``True``, JAX hands ownership of the buffer to DLPack,
and the consumer is free to mutate the buffer; the JAX buffer acts as if
it were deleted. If ``False``, JAX retains ownership of the buffer; it is
undefined behavior if the DLPack consumer writes to a buffer that JAX
owns.
stream: optional platform-dependent stream to wait on until the buffer is
ready. This corresponds to the `stream` argument to __dlpack__ documented
in https://dmlc.github.io/dlpack/latest/python_spec.html.
ready. This corresponds to the `stream` argument to ``__dlpack__``
documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
"""
if not isinstance(x, array.ArrayImpl):
raise TypeError("Argument to to_dlpack must be a jax.Array, "
@ -64,9 +64,9 @@ def to_dlpack(x: Array, take_ownership: bool = False,
def from_dlpack(dlpack):
"""Returns a ``DeviceArray`` representation of a DLPack tensor.
"""Returns a :class:`~jax.Array` representation of a DLPack tensor.
The returned ``DeviceArray`` shares memory with ``dlpack``.
The returned :class:`~jax.Array` shares memory with ``dlpack``.
Args:
dlpack: a DLPack tensor, on either CPU or GPU.

View File

@ -65,7 +65,7 @@ class IndexedAxisSize:
replace = dataclasses.replace
# Jumble(aval=a:3 => f32[[3 1 4].a],
# data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
# data=Array([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
@dataclasses.dataclass(frozen=True)
class Jumble:
aval: JumbleTy

View File

@ -217,7 +217,7 @@ def local_aval_to_result_handler(
Returns:
A function for handling the Buffers that will eventually be produced
for this output. The function will return an object suitable for returning
to the user, e.g. a ShardedDeviceArray.
to the user, e.g. an Array.
"""
try:
return local_result_handlers[(type(aval))](aval, sharding, indices)
@ -247,7 +247,7 @@ def global_aval_to_result_handler(
Returns:
A function for handling the Buffers that will eventually be produced
for this output. The function will return an object suitable for returning
to the user, e.g. a ShardedDeviceArray.
to the user, e.g. an Array.
"""
try:
return global_result_handlers[type(aval)](
@ -1048,7 +1048,7 @@ class InputsHandler:
class ResultsHandler:
# `out_avals` is the `GlobalDeviceArray` global avals when using pjit or xmap
# `out_avals` is the `Array` global avals when using pjit or xmap
# with `config.parallel_functions_output_gda=True`. It is the local one
# otherwise, and also when using `pmap`.
__slots__ = ("handlers", "out_shardings", "out_avals")

View File

@ -47,7 +47,7 @@ zip, unsafe_zip = safe_zip, zip
### add method and operator overloads to arraylike classes
# We add operator overloads to DeviceArray and ShapedArray. These method and
# We add operator overloads to Array and ShapedArray. These method and
# operator overloads mainly just forward calls to the corresponding lax_numpy
# functions, which can themselves handle instances from any of these classes.
@ -240,7 +240,7 @@ def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array:
def _notimplemented_flat(self):
raise NotImplementedError("JAX DeviceArrays do not implement the arr.flat property: "
raise NotImplementedError("JAX Arrays do not implement the arr.flat property: "
"consider arr.flatten() instead.")
_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, Array)
@ -308,8 +308,8 @@ def _multi_slice(arr: ArrayLike,
removed_dims: tuple[tuple[int, ...]]) -> list[Array]:
"""Extracts multiple slices from `arr`.
This is used to shard DeviceArray arguments to pmap. It's implemented as a
DeviceArray method here to avoid circular imports.
This is used to shard Array arguments to pmap. It's implemented as a
Array method here to avoid circular imports.
"""
results: list[Array] = []
for starts, limits, removed in zip(start_indices, limit_indices, removed_dims):
@ -746,7 +746,7 @@ def _set_tracer_aval_forwarding(tracer, exclude=()):
setattr(tracer, prop_name, _forward_property_to_aval(prop_name))
def _set_array_base_attributes(device_array, include=None, exclude=None):
# Forward operators, methods, and properties on DeviceArray to lax_numpy
# Forward operators, methods, and properties on Array to lax_numpy
# functions (with no Tracers involved; this forwarding is direct)
def maybe_setattr(attr_name, target):
if exclude is not None and attr_name in exclude:

View File

@ -321,7 +321,7 @@ def device_memory_profile(backend: Optional[str] = None) -> bytes:
"""Captures a JAX device memory profile as ``pprof``-format protocol buffer.
A device memory profile is a snapshot of the state of memory, that describes the JAX
:class:`jax.DeviceArray` and executable objects present in memory and their
:class:`~jax.Array` and executable objects present in memory and their
allocation sites.
For more information how to use the device memory profiler, see

View File

@ -13,12 +13,12 @@
# limitations under the License.
# A ShardingSpec describes at a high level how a logical array is sharded across
# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
# describe how to shard inputs to a parallel computation). spec_to_indices()
# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
# how the sharded array is "laid out" across devices. Given a sequence of
# devices, we shard the data across the devices in row-major order, with
# replication treated as an extra inner dimension.
# devices (each array sharded with a `PmapSharding` has a ShardingSpec, and
# ShardingSpecs also describe how to shard inputs to a parallel computation).
# spec_to_indices() encodes exactly how a given ShardingSpec is translated to
# device buffers, i.e. how the sharded array is "laid out" across devices. Given
# a sequence of devices, we shard the data across the devices in row-major
# order, with replication treated as an extra inner dimension.
#
# For example, given the logical data array [1, 2, 3, 4], if we were to
# partition this array 4 ways with a replication factor of 2, for a total of 8
@ -233,8 +233,8 @@ def spec_to_indices(shape: Sequence[int],
"""Returns numpy-style indices corresponding to a sharding spec.
Each index describes a shard of the array. The order of the indices is the
same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out
row-major).
same as the device_buffers of a Array sharded using PmapSharding (i.e. the
data is laid out row-major).
Args:
shape: The shape of the logical array being sharded.

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GlobalDeviceArray serialization and deserialization."""
"""Array serialization and deserialization."""
import abc
import asyncio
@ -482,7 +482,7 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
"""Responsible for serializing GDAs via TensorStore."""
def serialize(self, arrays, tensorstore_specs, *, on_commit_callback):
"""Serializes GlobalDeviceArrays or Arrays via TensorStore asynchronously.
"""Serializes Arrays or Arrays via TensorStore asynchronously.
TensorStore writes to a storage layer in 2 steps:
* Reading/copying from the source after which the source can be modified.
@ -494,7 +494,7 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
finish in a separate thread allowing other computation to proceed.
Args:
arrays: GlobalDeviceArrays or Arrays that should be serialized.
arrays: Arrays or Arrays that should be serialized.
tensorstore_specs: TensorStore specs that are used to serialize GDAs or
Arrays.
on_commit_callback: This callback will be executed after all processes

View File

@ -57,7 +57,7 @@ TfConcreteFunction = Any
TfVal = jax2tf_internal.TfVal
# The platforms for which to use DLPack to avoid copying (only works on GPU
# and CPU at the moment, and only for DeviceArray). For CPU we don't need
# and CPU at the moment, and only for Array). For CPU we don't need
# DLPack, if we are careful.
_DLPACK_PLATFORMS = ("gpu",)
@ -335,7 +335,7 @@ def _call_tf_impl(*args_jax_flat, callable_flat_tf, **_):
arg_jax.dtype in dlpack.SUPPORTED_DTYPES):
arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False)
return tf.experimental.dlpack.from_dlpack(arg_dlpack)
# The following avoids copies to the host on CPU, always for DeviceArray
# The following avoids copies to the host on CPU, always for Array
# and even for ndarray if they are sufficiently aligned.
# TODO(necula): on TPU this copies to the host!
return tf.constant(np.asarray(arg_jax))

View File

@ -1814,11 +1814,11 @@ def _not(x):
Numpy and JAX support bitwise not for booleans by applying a logical not!
This means that applying bitwise_not yields an unexpected result:
jnp.bitwise_not(jnp.array([True, False]))
>> DeviceArray([False, True], dtype=bool)
>> Array([False, True], dtype=bool)
if you assume that booleans are simply casted to integers.
jnp.bitwise_not(jnp.array([True, False]).astype(np.int32)).astype(bool)
>> DeviceArray([True, True], dtype=bool)
>> Array([True, True], dtype=bool)
"""
if x.dtype == tf.bool:
return tf.logical_not(x)

View File

@ -189,7 +189,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
f_tf = tf.function(lambda x: x + x)
self.assertEqual(f_tf(jnp.ones([])).numpy(), 2.)
# Test with ShardedDeviceArray.
# Test with a PmapSharding-sharded Array.
n = jax.local_device_count()
mk_sharded = lambda f: jax.pmap(lambda x: x)(f([n]))
f_tf = tf.function(lambda x: x)

View File

@ -226,7 +226,7 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
self.assertFalse(core.definitely_equal_one_of_dim(1, [2, b]))
self.assertFalse(core.definitely_equal_one_of_dim(3, []))
self.assertTrue(core.definitely_equal(1, jnp.add(0, 1))) # A DeviceArray
self.assertTrue(core.definitely_equal(1, jnp.add(0, 1))) # An Array
self.assertFalse(core.definitely_equal(1, "a"))
def test_poly_bounds(self):

View File

@ -2505,7 +2505,7 @@ class BCOO(JAXSparse):
@classmethod
def fromdense(cls, mat: Array, *, nse: int | None = None, index_dtype: DTypeLike = np.int32,
n_dense: int = 0, n_batch: int = 0) -> BCOO:
"""Create a BCOO array from a (dense) :class:`DeviceArray`."""
"""Create a BCOO array from a (dense) :class:`~jax.Array`."""
return bcoo_fromdense(
mat, nse=nse, index_dtype=index_dtype, n_dense=n_dense, n_batch=n_batch)

View File

@ -848,7 +848,7 @@ class BCSR(JAXSparse):
@classmethod
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32, n_dense=0,
n_batch=0):
"""Create a BCSR array from a (dense) :class:`DeviceArray`."""
"""Create a BCSR array from a (dense) :class:`Array`."""
return bcsr_fromdense(mat, nse=nse, index_dtype=index_dtype,
n_dense=n_dense, n_batch=n_batch)

View File

@ -2669,7 +2669,7 @@ class APITest(jtu.JaxTestCase):
error_text = "float0s do not support any operations by design"
with self.assertRaisesRegex(TypeError, error_text):
# dispatch via DeviceArray
# dispatch via Array
_ = float0_array + jnp.zeros(())
with self.assertRaisesRegex(TypeError, error_text):

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for GlobalDeviceArray."""
"""Tests for Array."""
import contextlib
import math

View File

@ -661,7 +661,7 @@ class TestPromotionTables(jtu.JaxTestCase):
for dtype in all_dtypes
for weak_type in [True, False]
)
def testDeviceArrayRepr(self, dtype, weak_type):
def testArrayRepr(self, dtype, weak_type):
val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
rep = repr(val)
self.assertStartsWith(rep, 'Array(')

View File

@ -2567,7 +2567,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
# TODO(mattjj): make the numpy.ndarray test pass w/ remat
raise unittest.SkipTest("new-remat-of-scan doesn't convert numpy.ndarray")
x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not DeviceArray
x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not Array
_, vjp_fun = jax.vjp(cumprod, x)
*_, ext_res = vjp_fun.args[0].args[0]
self.assertIsInstance(ext_res, jax.Array)

View File

@ -3196,7 +3196,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testIssue121(self):
assert not np.isscalar(jnp.array(3))
def testArrayOutputsDeviceArrays(self):
def testArrayOutputsArrays(self):
assert type(jnp.array([])) is array.ArrayImpl
assert type(jnp.array(np.array([]))) is array.ArrayImpl
@ -3206,10 +3206,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
assert type(jnp.array(NDArrayLike())) is array.ArrayImpl
# NOTE(mattjj): disabled b/c __array__ must produce ndarrays
# class DeviceArrayLike:
# class ArrayLike:
# def __array__(self, dtype=None):
# return jnp.array([], dtype=dtype)
# assert xla.type_is_device_array(jnp.array(DeviceArrayLike()))
# assert xla.type_is_device_array(jnp.array(ArrayLike()))
def testArrayMethod(self):
class arraylike:

View File

@ -320,8 +320,6 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
self.assertEqual(y[0], jax.device_count())
print(y)
# TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
# since `GlobalDeviceArray` is going to be deprecated in the future
def test_pjit_gda_multi_input_multi_output(self):
jax.distributed.initialize()
global_mesh = jtu.create_global_mesh((8, 2), ("x", "y"))
@ -370,8 +368,6 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
np.testing.assert_array_equal(np.asarray(s.data),
global_input_data[s.index])
# TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
# since `GlobalDeviceArray` is going to be deprecated in the future
def test_pjit_gda_non_contiguous_mesh(self):
jax.distributed.initialize()
devices = self.sorted_devices()
@ -428,8 +424,6 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
np.testing.assert_array_equal(np.asarray(s.data),
global_input_data[expected_index])
# TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
# since `GlobalDeviceArray` is going to be deprecated in the future
def test_pjit_gda_non_contiguous_mesh_2d(self):
jax.distributed.initialize()
global_mesh = self.create_2d_non_contiguous_mesh()
@ -504,8 +498,6 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
# Fully replicated values + GDA allows a non-contiguous mesh.
out1, out2 = f(global_input_data, gda2)
# TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
# since `GlobalDeviceArray` is going to be deprecated in the future
def test_pjit_gda_non_contiguous_mesh_2d_aot(self):
jax.distributed.initialize()
global_mesh = self.create_2d_non_contiguous_mesh()
@ -531,8 +523,6 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
self.assertEqual(out1.shape, (8, 2))
self.assertEqual(out2.shape, (8, 2))
# TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
# since `GlobalDeviceArray` is going to be deprecated in the future
def test_pjit_gda_eval_shape(self):
jax.distributed.initialize()

View File

@ -98,7 +98,7 @@ class CloudpickleTest(jtu.JaxTestCase):
class PickleTest(jtu.JaxTestCase):
def testPickleOfDeviceArray(self):
def testPickleOfArray(self):
x = jnp.arange(10.0)
s = pickle.dumps(x)
y = pickle.loads(s)
@ -106,7 +106,7 @@ class PickleTest(jtu.JaxTestCase):
self.assertIsInstance(y, type(x))
self.assertEqual(x.aval, y.aval)
def testPickleOfDeviceArrayWeakType(self):
def testPickleOfArrayWeakType(self):
x = jnp.array(4.0)
self.assertEqual(x.aval.weak_type, True)
s = pickle.dumps(x)

View File

@ -842,7 +842,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertNotIsInstance(z, np.ndarray)
self.assertAllClose(z, 2 * 2 * x, check_dtypes=False)
# test that we can pass in a regular DeviceArray
# test that we can pass in a regular Array
y = f(device_put(x))
self.assertIsInstance(y, array.ArrayImpl)
self.assertAllClose(y, 2 * x, check_dtypes=False)