Replace references to DeviceArray with Array.

A number of stale references are lurking in our documentation.

parent 97af33c4d1
commit 2c32660a8f
@@ -251,9 +251,9 @@
 "colab_type": "text"
 },
 "source": [
-"A `ShardedDeviceArray` is effectively an `ndarray` subclass, but it's stored in pieces spread across the memory of multiple devices. Results from `pmap` functions are left sharded in device memory so that they can be operated on by subsequent `pmap` functions without moving data around, at least in some cases. But these results logically appear just like a single array.\n",
+"A sharded `Array` is effectively an `ndarray` subclass, but it's stored in pieces spread across the memory of multiple devices. Results from `pmap` functions are left sharded in device memory so that they can be operated on by subsequent `pmap` functions without moving data around, at least in some cases. But these results logically appear just like a single array.\n",
 "\n",
-"When you call a non-`pmap` function on a `ShardedDeviceArray`, like a standard `jax.numpy` function, communication happens behind the scenes to bring the values to one device (or back to the host in the case of the matplotlib function above):"
+"When you call a non-`pmap` function on an `Array`, like a standard `jax.numpy` function, communication happens behind the scenes to bring the values to one device (or back to the host in the case of the matplotlib function above):"
 ]
 },
 {
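For readers skimming this hunk, a minimal runnable sketch of the behavior the prose describes (a hedged aside, not part of the commit; on a single-device host `jax.local_device_count()` is 1 and the sharding is trivial):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
xs = jnp.arange(n * 4.).reshape(n, 4)

# pmap results stay sharded across devices...
ys = jax.pmap(lambda x: x * 2)(xs)

# ...so a second pmap consumes them without moving data around,
zs = jax.pmap(lambda y: y + 1)(ys)

# while a non-pmap op like jnp.sum gathers the shards behind the scenes.
print(jnp.sum(zs))
```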
@@ -2344,8 +2344,8 @@
 "One piece missing is device memory persistence for arrays. That is, we've\n",
 "defined `handle_result` to transfer results back to CPU memory as NumPy\n",
 "arrays, but it's often preferable to avoid transferring results just to\n",
-"transfer them back for the next operation. We can do that by introducing a\n",
-"`DeviceArray` class, which can wrap XLA buffers and otherwise duck-type\n",
+"transfer them back for the next operation. We can do that by introducing an\n",
+"`Array` class, which can wrap XLA buffers and otherwise duck-type\n",
 "`numpy.ndarray`s:"
 ]
 },
@@ -2356,9 +2356,9 @@
 "outputs": [],
 "source": [
 "def handle_result(aval: ShapedArray, buf): # noqa: F811\n",
-" return DeviceArray(aval, buf)\n",
+" return Array(aval, buf)\n",
 "\n",
-"class DeviceArray:\n",
+"class Array:\n",
 " buf: Any\n",
 " aval: ShapedArray\n",
 "\n",
@@ -2381,9 +2381,9 @@
 " _rmul = staticmethod(mul)\n",
 " _gt = staticmethod(greater)\n",
 " _lt = staticmethod(less)\n",
-"input_handlers[DeviceArray] = lambda x: x.buf\n",
+"input_handlers[Array] = lambda x: x.buf\n",
 "\n",
-"jax_types.add(DeviceArray)"
+"jax_types.add(Array)"
 ]
 },
 {

@@ -1822,15 +1822,15 @@ print(ys)
 One piece missing is device memory persistence for arrays. That is, we've
 defined `handle_result` to transfer results back to CPU memory as NumPy
 arrays, but it's often preferable to avoid transferring results just to
-transfer them back for the next operation. We can do that by introducing a
-`DeviceArray` class, which can wrap XLA buffers and otherwise duck-type
+transfer them back for the next operation. We can do that by introducing an
+`Array` class, which can wrap XLA buffers and otherwise duck-type
 `numpy.ndarray`s:

 ```{code-cell}
 def handle_result(aval: ShapedArray, buf): # noqa: F811
-  return DeviceArray(aval, buf)
+  return Array(aval, buf)

-class DeviceArray:
+class Array:
   buf: Any
   aval: ShapedArray

@@ -1853,9 +1853,9 @@ class DeviceArray:
   _rmul = staticmethod(mul)
   _gt = staticmethod(greater)
   _lt = staticmethod(less)
-input_handlers[DeviceArray] = lambda x: x.buf
+input_handlers[Array] = lambda x: x.buf

-jax_types.add(DeviceArray)
+jax_types.add(Array)
 ```

 ```{code-cell}

@@ -1813,15 +1813,15 @@ print(ys)
 # One piece missing is device memory persistence for arrays. That is, we've
 # defined `handle_result` to transfer results back to CPU memory as NumPy
 # arrays, but it's often preferable to avoid transferring results just to
-# transfer them back for the next operation. We can do that by introducing a
-# `DeviceArray` class, which can wrap XLA buffers and otherwise duck-type
+# transfer them back for the next operation. We can do that by introducing an
+# `Array` class, which can wrap XLA buffers and otherwise duck-type
 # `numpy.ndarray`s:

 # +
 def handle_result(aval: ShapedArray, buf): # noqa: F811
-  return DeviceArray(aval, buf)
+  return Array(aval, buf)

-class DeviceArray:
+class Array:
   buf: Any
   aval: ShapedArray

@@ -1844,9 +1844,9 @@ class DeviceArray:
   _rmul = staticmethod(mul)
   _gt = staticmethod(greater)
   _lt = staticmethod(less)
-input_handlers[DeviceArray] = lambda x: x.buf
+input_handlers[Array] = lambda x: x.buf

-jax_types.add(DeviceArray)
+jax_types.add(Array)


 # +

@@ -68,7 +68,7 @@
 "source": [
 "So far, everything is just like NumPy. A big appeal of JAX is that you don't need to learn a new API. Many common NumPy programs would run just as well in JAX if you substitute `np` for `jnp`. However, there are some important differences which we touch on at the end of this section.\n",
 "\n",
-"You can notice the first difference if you check the type of `x`. It is a variable of type `DeviceArray`, which is the way JAX represents arrays."
+"You can notice the first difference if you check the type of `x`. It is a variable of type `Array`, which is the way JAX represents arrays."
 ]
 },
 {
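A quick sketch of the type check described above (a hedged aside; the exact class name printed varies by JAX version, e.g. `DeviceArray` in older releases, `ArrayImpl` in newer ones):

```python
import jax.numpy as jnp

x = jnp.arange(10)
print(type(x))  # e.g. <class 'jaxlib.xla_extension.ArrayImpl'>
print(x)        # Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
```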
@@ -81,7 +81,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)"
+"Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)"
 ]
 },
 "execution_count": 2,
@@ -277,8 +277,8 @@
 {
 "data": {
 "text/plain": [
-"(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
-" DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))"
+"(Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
+" Array([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))"
 ]
 },
 "execution_count": 7,
@@ -338,8 +338,8 @@
 {
 "data": {
 "text/plain": [
-"(DeviceArray(0.03999995, dtype=float32),\n",
-" DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))"
+"(Array(0.03999995, dtype=float32),\n",
+" Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))"
 ]
 },
 "execution_count": 8,
@@ -395,7 +395,7 @@
 "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
 "\u001b[0;31mFilteredStackTrace\u001b[0m Traceback (most recent call last)",
 "\u001b[0;32m<ipython-input-9-7433a86e7375>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquared_error_with_aux\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-"\u001b[0;31mFilteredStackTrace\u001b[0m: TypeError: Gradient only defined for scalar-output functions. Output was (DeviceArray(0.03999995, dtype=float32), DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).\n\nThe stack trace above excludes JAX-internal frames."
+"\u001b[0;31mFilteredStackTrace\u001b[0m: TypeError: Gradient only defined for scalar-output functions. Output was (Array(0.03999995, dtype=float32), Array([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).\n\nThe stack trace above excludes JAX-internal frames."
 ]
 }
 ],
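The traceback above arises from differentiating a function with a non-scalar (tuple) output; a hedged sketch of the standard fix with `has_aux=True` (the function body here is illustrative, not the notebook's exact code):

```python
import jax
import jax.numpy as jnp

def squared_error_with_aux(x, y):
  err = x - y
  return jnp.sum(err ** 2), err  # (scalar loss, auxiliary output)

x = jnp.ones(4)
y = jnp.full(4, 1.1)

# has_aux=True tells jax.grad the second return value is auxiliary data,
# so only the scalar first output is differentiated.
grads, aux = jax.grad(squared_error_with_aux, has_aux=True)(x, y)
print(grads, aux)
```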
@@ -425,8 +425,8 @@
 {
 "data": {
 "text/plain": [
-"(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
-" DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))"
+"(Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
+" Array([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))"
 ]
 },
 "execution_count": 10,
@@ -530,7 +530,7 @@
 "\u001b[0;32m<ipython-input-12-709e2d7ddd3f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0min_place_modify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Raises error when we cast input to jnp.ndarray\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
 "\u001b[0;32m<ipython-input-11-fce65eb843c7>\u001b[0m in \u001b[0;36min_place_modify\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0min_place_modify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m123\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
 "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py\u001b[0m in \u001b[0;36m_unimplemented_setitem\u001b[0;34m(self, i, x)\u001b[0m\n\u001b[1;32m 6594\u001b[0m \u001b[0;34m\"or another .at[] method: \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6595\u001b[0m \"https://jax.readthedocs.io/en/latest/jax.ops.html\")\n\u001b[0;32m-> 6596\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6597\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6598\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_operator_round\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumber\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mndigits\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-"\u001b[0;31mTypeError\u001b[0m: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html"
+"\u001b[0;31mTypeError\u001b[0m: '<class 'jaxlib.xla_extension.Array'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html"
 ]
 }
 ],
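A minimal sketch of the immutable-update API this error message recommends (matching the outputs in the next two hunks):

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
# x[0] = 123 would raise the TypeError above; use a functional update instead.
y = x.at[0].set(123)
print(y)  # Array([123, 2, 3], dtype=int32)
print(x)  # Array([1, 2, 3], dtype=int32) -- the original is untouched
```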
@@ -557,7 +557,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([123, 2, 3], dtype=int32)"
+"Array([123, 2, 3], dtype=int32)"
 ]
 },
 "execution_count": 13,
@@ -594,7 +594,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([1, 2, 3], dtype=int32)"
+"Array([1, 2, 3], dtype=int32)"
 ]
 },
 "execution_count": 14,

@@ -45,7 +45,7 @@ print(x)

 So far, everything is just like NumPy. A big appeal of JAX is that you don't need to learn a new API. Many common NumPy programs would run just as well in JAX if you substitute `np` for `jnp`. However, there are some important differences which we touch on at the end of this section.

-You can notice the first difference if you check the type of `x`. It is a variable of type `DeviceArray`, which is the way JAX represents arrays.
+You can notice the first difference if you check the type of `x`. It is a variable of type `Array`, which is the way JAX represents arrays.

 ```{code-cell} ipython3
 :id: 3fLtgPUAn7mi

@@ -401,7 +401,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(30, dtype=int32, weak_type=True)"
+"Array(30, dtype=int32, weak_type=True)"
 ]
 },
 "execution_count": 8,

@@ -37,7 +37,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([11., 20., 29.], dtype=float32)"
+"Array([11., 20., 29.], dtype=float32)"
 ]
 },
 "execution_count": 1,
@@ -104,7 +104,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([[11., 20., 29.],\n",
+"Array([[11., 20., 29.],\n",
 " [11., 20., 29.]], dtype=float32)"
 ]
 },
@@ -149,7 +149,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([[11., 20., 29.],\n",
+"Array([[11., 20., 29.],\n",
 " [11., 20., 29.]], dtype=float32)"
 ]
 },
@@ -201,7 +201,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([[11., 20., 29.],\n",
+"Array([[11., 20., 29.],\n",
 " [11., 20., 29.]], dtype=float32)"
 ]
 },
@@ -240,7 +240,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([[11., 11.],\n",
+"Array([[11., 11.],\n",
 " [20., 20.],\n",
 " [29., 29.]], dtype=float32)"
 ]
@@ -281,7 +281,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([[11., 20., 29.],\n",
+"Array([[11., 20., 29.],\n",
 " [11., 20., 29.]], dtype=float32)"
 ]
 },
@@ -320,7 +320,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([[11., 20., 29.],\n",
+"Array([[11., 20., 29.],\n",
 " [11., 20., 29.]], dtype=float32)"
 ]
 },
@@ -175,9 +175,9 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([[2., 0., 0.],\n",
-" [0., 2., 0.],\n",
-" [0., 0., 2.]], dtype=float32)"
+"Array([[2., 0., 0.],\n",
+" [0., 2., 0.],\n",
+" [0., 0., 2.]], dtype=float32)"
 ]
 },
 "execution_count": 6,
@@ -312,7 +312,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([ 2.4, -2.4, 2.4], dtype=float32)"
+"Array([ 2.4, -2.4, 2.4], dtype=float32)"
 ]
 },
 "execution_count": 9,
@@ -356,7 +356,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([-2.4, -4.8, 2.4], dtype=float32)"
+"Array([-2.4, -4.8, 2.4], dtype=float32)"
 ]
 },
 "execution_count": 10,
@@ -459,8 +459,8 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([[-2.4, -4.8, 2.4],\n",
-" [-2.4, -4.8, 2.4]], dtype=float32)"
+"Array([[-2.4, -4.8, 2.4],\n",
+" [-2.4, -4.8, 2.4]], dtype=float32)"
 ]
 },
 "execution_count": 12,
@@ -503,7 +503,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([-2.4, -4.8, 2.4], dtype=float32)"
+"Array([-2.4, -4.8, 2.4], dtype=float32)"
 ]
 },
 "execution_count": 13,
@@ -548,8 +548,8 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([[-2.4, -4.8, 2.4],\n",
-" [-2.4, -4.8, 2.4]], dtype=float32)"
+"Array([[-2.4, -4.8, 2.4],\n",
+" [-2.4, -4.8, 2.4]], dtype=float32)"
 ]
 },
 "execution_count": 14,
@@ -586,8 +586,8 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([[-2.4, -4.8, 2.4],\n",
-" [-2.4, -4.8, 2.4]], dtype=float32)"
+"Array([[-2.4, -4.8, 2.4],\n",
+" [-2.4, -4.8, 2.4]], dtype=float32)"
 ]
 },
 "execution_count": 15,
@@ -623,8 +623,8 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([[-2.4, -4.8, 2.4],\n",
-" [-2.4, -4.8, 2.4]], dtype=float32)"
+"Array([[-2.4, -4.8, 2.4],\n",
+" [-2.4, -4.8, 2.4]], dtype=float32)"
 ]
 },
 "execution_count": 16,

@@ -347,7 +347,7 @@
 "evalue": "ignored",
 "output_type": "error",
 "traceback": [
-"\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n"
+"\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n"
 ]
 }
 ],
@@ -587,7 +587,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(9, dtype=int32)"
+"Array(9, dtype=int32)"
 ]
 },
 "execution_count": 14,
@@ -754,7 +754,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(45, dtype=int32)"
+"Array(45, dtype=int32)"
 ]
 },
 "execution_count": 17,
@@ -848,7 +848,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(45, dtype=int32)"
+"Array(45, dtype=int32)"
 ]
 },
 "execution_count": 19,
@@ -1020,7 +1020,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([0, 0], dtype=uint32)"
+"Array([0, 0], dtype=uint32)"
 ]
 },
 "execution_count": 23,
@@ -1408,7 +1408,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(5., dtype=float32)"
+"Array(5., dtype=float32)"
 ]
 },
 "execution_count": 33,
@@ -1553,7 +1553,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(4, dtype=int32, weak_type=True)"
+"Array(4, dtype=int32, weak_type=True)"
 ]
 },
 "execution_count": 37,
@@ -1616,7 +1616,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([-1.], dtype=float32)"
+"Array([-1.], dtype=float32)"
 ]
 },
 "execution_count": 38,
@@ -1689,7 +1689,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(10, dtype=int32, weak_type=True)"
+"Array(10, dtype=int32, weak_type=True)"
 ]
 },
 "execution_count": 39,
@@ -1733,7 +1733,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(45, dtype=int32, weak_type=True)"
+"Array(45, dtype=int32, weak_type=True)"
 ]
 },
 "execution_count": 40,
@@ -2000,7 +2000,7 @@
 " 104 if np.any(np.isnan(py_val)):\n",
 "--> 105 raise FloatingPointError(\"invalid value\")\n",
 " 106 else:\n",
-" 107 return DeviceArray(device_buffer, *result_shape)\n",
+" 107 return Array(device_buffer, *result_shape)\n",
 "\n",
 "FloatingPointError: invalid value\n",
 "```"
@@ -2222,7 +2222,7 @@
 " array([254, 255, 0, 1], dtype=uint8)\n",
 "\n",
 " >>> jnp.arange(254.0, 258.0).astype('uint8')\n",
-" DeviceArray([254, 255, 255, 255], dtype=uint8)\n",
+" Array([254, 255, 255, 255], dtype=uint8)\n",
 " ```\n",
 " This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n",
 "\n",
@@ -987,7 +987,7 @@ FloatingPointError Traceback (most recent call last)
 104 if np.any(np.isnan(py_val)):
 --> 105 raise FloatingPointError("invalid value")
 106 else:
-107 return DeviceArray(device_buffer, *result_shape)
+107 return Array(device_buffer, *result_shape)

 FloatingPointError: invalid value
 ```
@@ -1142,7 +1142,7 @@ Many such cases are discussed in detail in the sections above; here we list seve
 array([254, 255, 0, 1], dtype=uint8)

 >>> jnp.arange(254.0, 258.0).astype('uint8')
-DeviceArray([254, 255, 255, 255], dtype=uint8)
+Array([254, 255, 255, 255], dtype=uint8)
 ```
 This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.


@@ -234,7 +234,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(3.0485873, dtype=float32)"
+"Array(3.0485873, dtype=float32)"
 ]
 },
 "execution_count": 8,
@@ -1578,7 +1578,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(-0.14112, dtype=float32)"
+"Array(-0.14112, dtype=float32)"
 ]
 },
 "execution_count": 50,
@@ -1901,7 +1901,7 @@
 "output_type": "stream",
 "text": [
 "called f_bwd!\n",
-"(DeviceArray(-0.9899925, dtype=float32),)\n"
+"(Array(-0.9899925, dtype=float32),)\n"
 ]
 }
 ],
@@ -2013,9 +2013,9 @@
 "> <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()\n",
 "-> return g\n",
 "(Pdb) p x\n",
-"DeviceArray(9., dtype=float32)\n",
+"Array(9., dtype=float32)\n",
 "(Pdb) p g\n",
-"DeviceArray(-0.91113025, dtype=float32)\n",
+"Array(-0.91113025, dtype=float32)\n",
 "(Pdb) q\n",
 "```"
 ]
@@ -2085,7 +2085,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"{'a': 1.0, 'b': (DeviceArray(0.841471, dtype=float32), DeviceArray(-0.4161468, dtype=float32))}\n"
+"{'a': 1.0, 'b': (Array(0.841471, dtype=float32), Array(-0.4161468, dtype=float32))}\n"
 ]
 }
 ],
@@ -2107,7 +2107,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Point(x=DeviceArray(2.5403023, dtype=float32), y=array(0., dtype=float32))\n"
+"Point(x=Array(2.5403023, dtype=float32), y=array(0., dtype=float32))\n"
 ]
 }
 ],
@@ -2166,7 +2166,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"{'a': 1.0, 'b': (DeviceArray(0.841471, dtype=float32), DeviceArray(-0.4161468, dtype=float32))}\n"
+"{'a': 1.0, 'b': (Array(0.841471, dtype=float32), Array(-0.4161468, dtype=float32))}\n"
 ]
 }
 ],
@@ -2188,7 +2188,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Point(x=DeviceArray(2.5403023, dtype=float32), y=DeviceArray(-0., dtype=float32))\n"
+"Point(x=Array(2.5403023, dtype=float32), y=Array(-0., dtype=float32))\n"
 ]
 }
 ],
@@ -1036,9 +1036,9 @@ jax.grad(foo)(3.)
 > <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()
 -> return g
 (Pdb) p x
-DeviceArray(9., dtype=float32)
+Array(9., dtype=float32)
 (Pdb) p g
-DeviceArray(-0.91113025, dtype=float32)
+Array(-0.91113025, dtype=float32)
 (Pdb) q
 ```


@@ -236,7 +236,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"{'W': DeviceArray([-0.16965576, -0.8774645 , -1.4901344 ], dtype=float32), 'b': DeviceArray(-0.29227236, dtype=float32)}\n"
+"{'W': Array([-0.16965576, -0.8774645 , -1.4901344 ], dtype=float32), 'b': Array(-0.29227236, dtype=float32)}\n"
 ]
 }
 ],
@@ -1204,7 +1204,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"(DeviceArray(3.1415927, dtype=float32),)\n"
+"(Array(3.1415927, dtype=float32),)\n"
 ]
 }
 ],
@@ -1470,7 +1470,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(6.-8.j, dtype=complex64)"
+"Array(6.-8.j, dtype=complex64)"
 ]
 },
 "execution_count": 31,
@@ -1511,7 +1511,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(-27.034945-3.8511531j, dtype=complex64)"
+"Array(-27.034945-3.8511531j, dtype=complex64)"
 ]
 },
 "execution_count": 32,
@@ -1549,7 +1549,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(1.-0.j, dtype=complex64)"
+"Array(1.-0.j, dtype=complex64)"
 ]
 },
 "execution_count": 33,
@@ -1602,12 +1602,12 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([[-0.75342447 +0.j , -3.0509021 -10.940544j ,\n",
-" 5.989684 +3.5422976j],\n",
-" [-3.0509021 +10.940544j , -8.904487 +0.j ,\n",
-" -5.1351547 -6.5593696j],\n",
-" [ 5.989684 -3.5422976j, -5.1351547 +6.5593696j,\n",
-" 0.01320434 +0.j ]], dtype=complex64)"
+"Array([[-0.75342447 +0.j , -3.0509021 -10.940544j ,\n",
+" 5.989684 +3.5422976j],\n",
+" [-3.0509021 +10.940544j , -8.904487 +0.j ,\n",
+" -5.1351547 -6.5593696j],\n",
+" [ 5.989684 -3.5422976j, -5.1351547 +6.5593696j,\n",
+" 0.01320434 +0.j ]], dtype=complex64)"
 ]
 },
 "execution_count": 34,
@@ -1071,7 +1071,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(-0.4003078, dtype=float32, weak_type=True)"
+"Array(-0.4003078, dtype=float32, weak_type=True)"
 ]
 },
 "execution_count": 8,

@@ -170,7 +170,7 @@
 {
 "data": {
 "text/plain": [
-"jax.interpreters.xla._DeviceArray"
+"jaxlib.xla_extension.ArrayImpl"
 ]
 },
 "execution_count": 5,
@@ -248,7 +248,7 @@
 "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
 "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
 "\u001b[0;32m<ipython-input-7-6b90817377fe>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# JAX: immutable arrays\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-"\u001b[0;31mTypeError\u001b[0m: '<class 'jax.interpreters.xla._DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?"
+"\u001b[0;31mTypeError\u001b[0m: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?"
 ]
 }
 ],
@@ -327,7 +327,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(2., dtype=float32)"
+"Array(2., dtype=float32)"
 ]
 },
 "execution_count": 9,
@@ -390,7 +390,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(2., dtype=float32)"
+"Array(2., dtype=float32)"
 ]
 },
 "execution_count": 11,
@@ -426,7 +426,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)"
+"Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)"
 ]
 },
 "execution_count": 12,
@@ -462,7 +462,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)"
+"Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)"
 ]
 },
 "execution_count": 13,
@@ -638,7 +638,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)"
+"Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)"
 ]
 },
 "execution_count": 18,
@@ -739,7 +739,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([0.25773212, 5.3623195 , 5.4032435 ], dtype=float32)"
+"Array([0.25773212, 5.3623195 , 5.4032435 ], dtype=float32)"
 ]
 },
 "execution_count": 20,
@@ -788,7 +788,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([1.4344584, 4.3004413, 7.9897013], dtype=float32)"
+"Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)"
 ]
 },
 "execution_count": 21,
@@ -908,7 +908,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(-1, dtype=int32)"
+"Array(-1, dtype=int32)"
 ]
 },
 "execution_count": 24,
@@ -948,7 +948,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(1, dtype=int32)"
+"Array(1, dtype=int32)"
 ]
 },
 "execution_count": 25,
@@ -1086,7 +1086,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)"
+"Array([1., 1., 1., 1., 1., 1.], dtype=float32)"
 ]
 },
 "execution_count": 28,
@@ -145,7 +145,7 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray(-213.23558, dtype=float32)"
+"Array(-213.23558, dtype=float32)"
 ]
 },
 "execution_count": 13,
@@ -229,9 +229,9 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([-147.84033203, -207.02204895, -109.26074982, -243.80830383,\n",
-" -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n",
-" -126.60544586, -190.81988525], dtype=float32)"
+"Array([-147.84033203, -207.02204895, -109.26074982, -243.80830383,\n",
+" -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n",
+" -126.60544586, -190.81988525], dtype=float32)"
 ]
 },
 "execution_count": 16,
@@ -270,9 +270,9 @@
 {
 "data": {
 "text/plain": [
-"DeviceArray([-147.84033203, -207.02204895, -109.26074982, -243.80830383,\n",
-" -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n",
-" -126.60544586, -190.81988525], dtype=float32)"
+"Array([-147.84033203, -207.02204895, -109.26074982, -243.80830383,\n",
+" -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n",
+" -126.60544586, -190.81988525], dtype=float32)"
 ]
 },
 "execution_count": 17,
@@ -649,9 +649,8 @@ def make_array_from_single_device_arrays(
 ) -> ArrayImpl:
   r"""Returns a ``jax.Array`` from a sequence of ``jax.Array``\s on a single device.

-  ``jax.Array`` on a single device is analogous to a ``DeviceArray``. You can use
-  this function if you have already ``jax.device_put`` the value on a single
-  device and want to create a global Array. The smaller ``jax.Array``\s should be
+  You can use this function if you have already ``jax.device_put`` the value on
+  a single device and want to create a global Array. The smaller ``jax.Array``\s should be
   addressable and belong to the current process.

   Args:
@@ -702,8 +701,7 @@ def make_array_from_single_device_arrays(
   aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
   if dtypes.issubdtype(aval.dtype, dtypes.extended):
     return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True)
-  # TODO(phawkins): ideally the cast() could be checked. Revisit this after
-  # removing DeviceArray.
+  # TODO(phawkins): ideally the cast() could be checked.
   return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),
                    committed=True)
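A hedged usage sketch for this function (the shape and mesh axis name are illustrative; it assumes one addressable shard per local device):

```python
import math

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

global_shape = (jax.device_count() * 2, 4)
mesh = Mesh(np.array(jax.devices()), ('x',))
sharding = NamedSharding(mesh, P('x'))

# device_put each shard onto its device first...
data = np.arange(math.prod(global_shape)).reshape(global_shape)
arrays = [
    jax.device_put(data[index], device)
    for device, index in sharding.addressable_devices_indices_map(global_shape).items()
]

# ...then stitch the single-device pieces into one global jax.Array.
arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
print(arr.shape)
```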

@@ -1620,7 +1620,7 @@ class ConcreteArray(ShapedArray):
   def __eq__(self, other):
     if (type(self) is type(other) and self.dtype == other.dtype
         and self.shape == other.shape and self.weak_type == other.weak_type):
-      with eval_context(): # in case self.val is a DeviceArray
+      with eval_context(): # in case self.val is an Array
         return (self.val == other.val).all()
     else:
       return False

@@ -30,21 +30,21 @@ SUPPORTED_DTYPES = frozenset({

 def to_dlpack(x: Array, take_ownership: bool = False,
               stream: int | None = None):
-  """Returns a DLPack tensor that encapsulates a ``DeviceArray`` `x`.
+  """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.

-  Takes ownership of the contents of ``x``; leaves `x` in an invalid/deleted
+  Takes ownership of the contents of ``x``; leaves ``x`` in an invalid/deleted
   state.

   Args:
-    x: a ``DeviceArray``, on either CPU or GPU.
+    x: a :class:`~jax.Array`, on either CPU or GPU.
     take_ownership: If ``True``, JAX hands ownership of the buffer to DLPack,
       and the consumer is free to mutate the buffer; the JAX buffer acts as if
       it were deleted. If ``False``, JAX retains ownership of the buffer; it is
       undefined behavior if the DLPack consumer writes to a buffer that JAX
       owns.
     stream: optional platform-dependent stream to wait on until the buffer is
-      ready. This corresponds to the `stream` argument to __dlpack__ documented
-      in https://dmlc.github.io/dlpack/latest/python_spec.html.
+      ready. This corresponds to the `stream` argument to ``__dlpack__``
+      documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
   """
   if not isinstance(x, array.ArrayImpl):
     raise TypeError("Argument to to_dlpack must be a jax.Array, "
@@ -64,9 +64,9 @@ def to_dlpack(x: Array, take_ownership: bool = False,


 def from_dlpack(dlpack):
-  """Returns a ``DeviceArray`` representation of a DLPack tensor.
+  """Returns a :class:`~jax.Array` representation of a DLPack tensor.

-  The returned ``DeviceArray`` shares memory with ``dlpack``.
+  The returned :class:`~jax.Array` shares memory with ``dlpack``.

   Args:
     dlpack: a DLPack tensor, on either CPU or GPU.
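A round-trip sketch using the two functions documented above (this reflects the older capsule-based API shown in the diff; newer JAX versions also accept objects implementing the `__dlpack__` protocol directly):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)

# Export to a DLPack capsule; take_ownership=False leaves x valid.
capsule = jax.dlpack.to_dlpack(x, take_ownership=False)

# Import it back as a jax.Array that shares the same device memory.
y = jax.dlpack.from_dlpack(capsule)
print(y)  # Array([0., 1., 2., 3.], dtype=float32)
```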

@@ -65,7 +65,7 @@ class IndexedAxisSize:
   replace = dataclasses.replace

 # Jumble(aval=a:3 => f32[[3 1 4].a],
-#        data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
+#        data=Array([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
 @dataclasses.dataclass(frozen=True)
 class Jumble:
   aval: JumbleTy

@@ -217,7 +217,7 @@ def local_aval_to_result_handler(
   Returns:
     A function for handling the Buffers that will eventually be produced
     for this output. The function will return an object suitable for returning
-    to the user, e.g. a ShardedDeviceArray.
+    to the user, e.g. an Array.
   """
   try:
     return local_result_handlers[(type(aval))](aval, sharding, indices)
@@ -247,7 +247,7 @@ def global_aval_to_result_handler(
   Returns:
     A function for handling the Buffers that will eventually be produced
     for this output. The function will return an object suitable for returning
-    to the user, e.g. a ShardedDeviceArray.
+    to the user, e.g. an Array.
   """
   try:
     return global_result_handlers[type(aval)](
@@ -1048,7 +1048,7 @@ class InputsHandler:


 class ResultsHandler:
-  # `out_avals` is the `GlobalDeviceArray` global avals when using pjit or xmap
+  # `out_avals` is the `Array` global avals when using pjit or xmap
   # with `config.parallel_functions_output_gda=True`. It is the local one
   # otherwise, and also when using `pmap`.
   __slots__ = ("handlers", "out_shardings", "out_avals")

@@ -47,7 +47,7 @@ zip, unsafe_zip = safe_zip, zip

 ### add method and operator overloads to arraylike classes

-# We add operator overloads to DeviceArray and ShapedArray. These method and
+# We add operator overloads to Array and ShapedArray. These method and
 # operator overloads mainly just forward calls to the corresponding lax_numpy
 # functions, which can themselves handle instances from any of these classes.

@@ -240,7 +240,7 @@ def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array:


 def _notimplemented_flat(self):
-  raise NotImplementedError("JAX DeviceArrays do not implement the arr.flat property: "
+  raise NotImplementedError("JAX Arrays do not implement the arr.flat property: "
                             "consider arr.flatten() instead.")

 _accepted_binop_types = (int, float, complex, np.generic, np.ndarray, Array)
@@ -308,8 +308,8 @@ def _multi_slice(arr: ArrayLike,
                  removed_dims: tuple[tuple[int, ...]]) -> list[Array]:
   """Extracts multiple slices from `arr`.

-  This is used to shard DeviceArray arguments to pmap. It's implemented as a
-  DeviceArray method here to avoid circular imports.
+  This is used to shard Array arguments to pmap. It's implemented as a
+  Array method here to avoid circular imports.
   """
   results: list[Array] = []
   for starts, limits, removed in zip(start_indices, limit_indices, removed_dims):
@@ -746,7 +746,7 @@ def _set_tracer_aval_forwarding(tracer, exclude=()):
     setattr(tracer, prop_name, _forward_property_to_aval(prop_name))

 def _set_array_base_attributes(device_array, include=None, exclude=None):
-  # Forward operators, methods, and properties on DeviceArray to lax_numpy
+  # Forward operators, methods, and properties on Array to lax_numpy
   # functions (with no Tracers involved; this forwarding is direct)
   def maybe_setattr(attr_name, target):
     if exclude is not None and attr_name in exclude:

@@ -321,7 +321,7 @@ def device_memory_profile(backend: Optional[str] = None) -> bytes:
   """Captures a JAX device memory profile as ``pprof``-format protocol buffer.

   A device memory profile is a snapshot of the state of memory, that describes the JAX
-  :class:`jax.DeviceArray` and executable objects present in memory and their
+  :class:`~jax.Array` and executable objects present in memory and their
   allocation sites.

   For more information how to use the device memory profiler, see
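A short usage sketch for this entry point (the filename is illustrative; the captured bytes can be inspected with the `pprof` tool):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1000, 1000))  # keep something live so the snapshot is non-empty

profile = jax.profiler.device_memory_profile()  # pprof-format bytes
with open("memory.pprof", "wb") as f:
  f.write(profile)
```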

@@ -13,12 +13,12 @@
 # limitations under the License.

 # A ShardingSpec describes at a high level how a logical array is sharded across
-# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also
-# describe how to shard inputs to a parallel computation). spec_to_indices()
-# encodes exactly how a given ShardingSpec is translated to device buffers, i.e.
-# how the sharded array is "laid out" across devices. Given a sequence of
-# devices, we shard the data across the devices in row-major order, with
-# replication treated as an extra inner dimension.
+# devices (each array sharded with a `PmapSharding` has a ShardingSpec, and
+# ShardingSpecs also describe how to shard inputs to a parallel computation).
+# spec_to_indices() encodes exactly how a given ShardingSpec is translated to
+# device buffers, i.e. how the sharded array is "laid out" across devices. Given
+# a sequence of devices, we shard the data across the devices in row-major
+# order, with replication treated as an extra inner dimension.
 #
 # For example, given the logical data array [1, 2, 3, 4], if we were to
 # partition this array 4 ways with a replication factor of 2, for a total of 8
@@ -233,8 +233,8 @@ def spec_to_indices(shape: Sequence[int],
   """Returns numpy-style indices corresponding to a sharding spec.

   Each index describes a shard of the array. The order of the indices is the
-  same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out
-  row-major).
+  same as the device_buffers of a Array sharded using PmapSharding (i.e. the
+  data is laid out row-major).

   Args:
     shape: The shape of the logical array being sharded.
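A hedged sketch of how this row-major layout surfaces in the public API: a pmap result carries a `PmapSharding`, and `devices_indices_map` reports which numpy-style index each device holds (method name per the `jax.sharding.Sharding` interface; the output depends on the local device count):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
y = jax.pmap(lambda x: x * 2)(jnp.arange(n))

print(type(y.sharding))  # e.g. <class 'jax.sharding.PmapSharding'>
for device, index in y.sharding.devices_indices_map(y.shape).items():
  print(device, index)   # device -> index of the shard laid out on it
```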

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""GlobalDeviceArray serialization and deserialization."""
+"""Array serialization and deserialization."""

 import abc
 import asyncio
@@ -482,7 +482,7 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
   """Responsible for serializing GDAs via TensorStore."""

   def serialize(self, arrays, tensorstore_specs, *, on_commit_callback):
-    """Serializes GlobalDeviceArrays or Arrays via TensorStore asynchronously.
+    """Serializes Arrays or Arrays via TensorStore asynchronously.

     TensorStore writes to a storage layer in 2 steps:
     * Reading/copying from the source after which the source can be modified.
@@ -494,7 +494,7 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas
     finish in a separate thread allowing other computation to proceed.

     Args:
-      arrays: GlobalDeviceArrays or Arrays that should be serialized.
+      arrays: Arrays or Arrays that should be serialized.
       tensorstore_specs: TensorStore specs that are used to serialize GDAs or
         Arrays.
       on_commit_callback: This callback will be executed after all processes

@@ -57,7 +57,7 @@ TfConcreteFunction = Any
 TfVal = jax2tf_internal.TfVal

 # The platforms for which to use DLPack to avoid copying (only works on GPU
-# and CPU at the moment, and only for DeviceArray). For CPU we don't need
+# and CPU at the moment, and only for Array). For CPU we don't need
 # DLPack, if we are careful.
 _DLPACK_PLATFORMS = ("gpu",)

@@ -335,7 +335,7 @@ def _call_tf_impl(*args_jax_flat, callable_flat_tf, **_):
         arg_jax.dtype in dlpack.SUPPORTED_DTYPES):
       arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False)
       return tf.experimental.dlpack.from_dlpack(arg_dlpack)
-    # The following avoids copies to the host on CPU, always for DeviceArray
+    # The following avoids copies to the host on CPU, always for Array
     # and even for ndarray if they are sufficiently aligned.
     # TODO(necula): on TPU this copies to the host!
     return tf.constant(np.asarray(arg_jax))

@@ -1814,11 +1814,11 @@ def _not(x):
   Numpy and JAX support bitwise not for booleans by applying a logical not!
   This means that applying bitwise_not yields an unexpected result:
   jnp.bitwise_not(jnp.array([True, False]))
-  >> DeviceArray([False, True], dtype=bool)
+  >> Array([False, True], dtype=bool)

   if you assume that booleans are simply casted to integers.
   jnp.bitwise_not(jnp.array([True, False]).astype(np.int32)).astype(bool)
-  >> DeviceArray([True, True], dtype=bool)
+  >> Array([True, True], dtype=bool)
   """
   if x.dtype == tf.bool:
     return tf.logical_not(x)

@@ -189,7 +189,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     f_tf = tf.function(lambda x: x + x)
     self.assertEqual(f_tf(jnp.ones([])).numpy(), 2.)

-    # Test with ShardedDeviceArray.
+    # Test with a PmapSharding-sharded Array.
     n = jax.local_device_count()
     mk_sharded = lambda f: jax.pmap(lambda x: x)(f([n]))
     f_tf = tf.function(lambda x: x)

@@ -226,7 +226,7 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
     self.assertFalse(core.definitely_equal_one_of_dim(1, [2, b]))
     self.assertFalse(core.definitely_equal_one_of_dim(3, []))

-    self.assertTrue(core.definitely_equal(1, jnp.add(0, 1))) # A DeviceArray
+    self.assertTrue(core.definitely_equal(1, jnp.add(0, 1))) # An Array
     self.assertFalse(core.definitely_equal(1, "a"))

   def test_poly_bounds(self):

@@ -2505,7 +2505,7 @@ class BCOO(JAXSparse):
   @classmethod
   def fromdense(cls, mat: Array, *, nse: int | None = None, index_dtype: DTypeLike = np.int32,
                 n_dense: int = 0, n_batch: int = 0) -> BCOO:
-    """Create a BCOO array from a (dense) :class:`DeviceArray`."""
+    """Create a BCOO array from a (dense) :class:`~jax.Array`."""
     return bcoo_fromdense(
         mat, nse=nse, index_dtype=index_dtype, n_dense=n_dense, n_batch=n_batch)

@@ -848,7 +848,7 @@ class BCSR(JAXSparse):
   @classmethod
   def fromdense(cls, mat, *, nse=None, index_dtype=np.int32, n_dense=0,
                 n_batch=0):
-    """Create a BCSR array from a (dense) :class:`DeviceArray`."""
+    """Create a BCSR array from a (dense) :class:`Array`."""
     return bcsr_fromdense(mat, nse=nse, index_dtype=index_dtype,
                           n_dense=n_dense, n_batch=n_batch)

@@ -2669,7 +2669,7 @@ class APITest(jtu.JaxTestCase):
     error_text = "float0s do not support any operations by design"

     with self.assertRaisesRegex(TypeError, error_text):
-      # dispatch via DeviceArray
+      # dispatch via Array
       _ = float0_array + jnp.zeros(())

     with self.assertRaisesRegex(TypeError, error_text):

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for GlobalDeviceArray."""
+"""Tests for Array."""

 import contextlib
 import math

@@ -661,7 +661,7 @@ class TestPromotionTables(jtu.JaxTestCase):
       for dtype in all_dtypes
       for weak_type in [True, False]
   )
-  def testDeviceArrayRepr(self, dtype, weak_type):
+  def testArrayRepr(self, dtype, weak_type):
     val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
     rep = repr(val)
     self.assertStartsWith(rep, 'Array(')

@@ -2567,7 +2567,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
       # TODO(mattjj): make the numpy.ndarray test pass w/ remat
      raise unittest.SkipTest("new-remat-of-scan doesn't convert numpy.ndarray")

-    x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not DeviceArray
+    x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not Array
     _, vjp_fun = jax.vjp(cumprod, x)
     *_, ext_res = vjp_fun.args[0].args[0]
     self.assertIsInstance(ext_res, jax.Array)

@@ -3196,7 +3196,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
   def testIssue121(self):
     assert not np.isscalar(jnp.array(3))

-  def testArrayOutputsDeviceArrays(self):
+  def testArrayOutputsArrays(self):
     assert type(jnp.array([])) is array.ArrayImpl
     assert type(jnp.array(np.array([]))) is array.ArrayImpl

@@ -3206,10 +3206,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     assert type(jnp.array(NDArrayLike())) is array.ArrayImpl

     # NOTE(mattjj): disabled b/c __array__ must produce ndarrays
-    # class DeviceArrayLike:
+    # class ArrayLike:
     #   def __array__(self, dtype=None):
     #     return jnp.array([], dtype=dtype)
-    # assert xla.type_is_device_array(jnp.array(DeviceArrayLike()))
+    # assert xla.type_is_device_array(jnp.array(ArrayLike()))

   def testArrayMethod(self):
     class arraylike:

@@ -320,8 +320,6 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
     self.assertEqual(y[0], jax.device_count())
     print(y)

-  # TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
-  # since `GlobalDeviceArray` is going to be deprecated in the future
   def test_pjit_gda_multi_input_multi_output(self):
     jax.distributed.initialize()
     global_mesh = jtu.create_global_mesh((8, 2), ("x", "y"))
@@ -370,8 +368,6 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
     np.testing.assert_array_equal(np.asarray(s.data),
                                   global_input_data[s.index])

-  # TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
-  # since `GlobalDeviceArray` is going to be deprecated in the future
   def test_pjit_gda_non_contiguous_mesh(self):
     jax.distributed.initialize()
     devices = self.sorted_devices()
@@ -428,8 +424,6 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
     np.testing.assert_array_equal(np.asarray(s.data),
                                   global_input_data[expected_index])

-  # TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
-  # since `GlobalDeviceArray` is going to be deprecated in the future
   def test_pjit_gda_non_contiguous_mesh_2d(self):
     jax.distributed.initialize()
     global_mesh = self.create_2d_non_contiguous_mesh()
@@ -504,8 +498,6 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
     # Fully replicated values + GDA allows a non-contiguous mesh.
     out1, out2 = f(global_input_data, gda2)

-  # TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
-  # since `GlobalDeviceArray` is going to be deprecated in the future
   def test_pjit_gda_non_contiguous_mesh_2d_aot(self):
     jax.distributed.initialize()
     global_mesh = self.create_2d_non_contiguous_mesh()
@@ -531,8 +523,6 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
     self.assertEqual(out1.shape, (8, 2))
     self.assertEqual(out2.shape, (8, 2))

-  # TODO(sudhakarsingh27): To change/omit test in favor of using `Array`
-  # since `GlobalDeviceArray` is going to be deprecated in the future
   def test_pjit_gda_eval_shape(self):
     jax.distributed.initialize()


@@ -98,7 +98,7 @@ class CloudpickleTest(jtu.JaxTestCase):

 class PickleTest(jtu.JaxTestCase):

-  def testPickleOfDeviceArray(self):
+  def testPickleOfArray(self):
     x = jnp.arange(10.0)
     s = pickle.dumps(x)
     y = pickle.loads(s)
@@ -106,7 +106,7 @@ class PickleTest(jtu.JaxTestCase):
     self.assertIsInstance(y, type(x))
     self.assertEqual(x.aval, y.aval)

-  def testPickleOfDeviceArrayWeakType(self):
+  def testPickleOfArrayWeakType(self):
     x = jnp.array(4.0)
     self.assertEqual(x.aval.weak_type, True)
     s = pickle.dumps(x)

@@ -842,7 +842,7 @@ class PythonPmapTest(jtu.JaxTestCase):
     self.assertNotIsInstance(z, np.ndarray)
     self.assertAllClose(z, 2 * 2 * x, check_dtypes=False)

-    # test that we can pass in a regular DeviceArray
+    # test that we can pass in a regular Array
     y = f(device_put(x))
     self.assertIsInstance(y, array.ArrayImpl)
     self.assertAllClose(y, 2 * x, check_dtypes=False)