mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Make colab_gpu.ipynb compatible with newer JAX versions
PiperOrigin-RevId: 509356393
This commit is contained in:
parent
d0eedf7e57
commit
e1ff0c1d7a
@ -90,13 +90,12 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"from jaxlib import xla_extension\n",
|
||||
"import jax\n",
|
||||
"key = jax.random.PRNGKey(1701)\n",
|
||||
"arr = jax.random.normal(key, (1000,))\n",
|
||||
"device = arr.device_buffer.device()\n",
|
||||
"device = arr.device()\n",
|
||||
"print(f\"JAX device type: {device}\")\n",
|
||||
"assert isinstance(device, xla_extension.GpuDevice), \"unexpected JAX device type\""
|
||||
"assert device.platform == \"gpu\", \"unexpected JAX device type\""
|
||||
],
|
||||
"execution_count": 2,
|
||||
"outputs": [
|
||||
|
Loading…
x
Reference in New Issue
Block a user