Make colab_gpu.ipynb compatible with newer JAX versions

PiperOrigin-RevId: 509356393
This commit is contained in:
Jake VanderPlas 2023-02-13 15:56:09 -08:00 committed by jax authors
parent d0eedf7e57
commit e1ff0c1d7a

View File

@ -90,13 +90,12 @@
}
},
"source": [
"from jaxlib import xla_extension\n",
"import jax\n",
"key = jax.random.PRNGKey(1701)\n",
"arr = jax.random.normal(key, (1000,))\n",
"device = arr.device_buffer.device()\n",
"device = arr.device()\n",
"print(f\"JAX device type: {device}\")\n",
"assert isinstance(device, xla_extension.GpuDevice), \"unexpected JAX device type\""
"assert device.platform == \"gpu\", \"unexpected JAX device type\""
],
"execution_count": 2,
"outputs": [