mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Update convolutions.md
Update convolutions.ipynb Update convolutions.ipynb Update convolutions.md
This commit is contained in:
parent
e51c98fbe1
commit
8d54a53b53
@ -205,10 +205,10 @@
|
||||
],
|
||||
"source": [
|
||||
"# 2D kernel - HWIO layout\n",
|
||||
"kernel = jnp.zeros((3, 3, 3, 3), dtype=np.float32)\n",
|
||||
"kernel = jnp.zeros((3, 3, 3, 3), dtype=jnp.float32)\n",
|
||||
"kernel += jnp.array([[1, 1, 0],\n",
|
||||
" [1, 0,-1],\n",
|
||||
" [0,-1,-1]])[:, :, np.newaxis, np.newaxis]\n",
|
||||
" [0,-1,-1]])[:, :, jnp.newaxis, jnp.newaxis]\n",
|
||||
"\n",
|
||||
"print(\"Edge Conv kernel:\")\n",
|
||||
"plt.imshow(kernel[:, :, 0, 0]);"
|
||||
@ -804,7 +804,7 @@
|
||||
],
|
||||
"source": [
|
||||
"# 1D kernel - WIO layout\n",
|
||||
"kernel = np.array([[[1, 0, -1], [-1, 0, 1]], \n",
|
||||
"kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], \n",
|
||||
" [[1, 1, 1], [-1, -1, -1]]], \n",
|
||||
" dtype=jnp.float32).transpose([2,1,0])\n",
|
||||
"# 1D data - NWC layout\n",
|
||||
@ -891,16 +891,16 @@
|
||||
"import matplotlib as mpl\n",
|
||||
"\n",
|
||||
"# Random 3D kernel - HWDIO layout\n",
|
||||
"kernel = np.array([\n",
|
||||
"kernel = jnp.array([\n",
|
||||
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]],\n",
|
||||
" [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], \n",
|
||||
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], \n",
|
||||
" dtype=jnp.float32)[:, :, :, np.newaxis, np.newaxis]\n",
|
||||
" dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]\n",
|
||||
"\n",
|
||||
"# 3D data - NHWDC layout\n",
|
||||
"data = np.zeros((1, 30, 30, 30, 1), dtype=jnp.float32)\n",
|
||||
"data = jnp.zeros((1, 30, 30, 30, 1), dtype=jnp.float32)\n",
|
||||
"x, y, z = np.mgrid[0:1:30j, 0:1:30j, 0:1:30j]\n",
|
||||
"data += (np.sin(2*x*jnp.pi)*np.cos(2*y*jnp.pi)*np.cos(2*z*jnp.pi))[None,:,:,:,None]\n",
|
||||
"data += (jnp.sin(2*x*jnp.pi)*jnp.cos(2*y*jnp.pi)*jnp.cos(2*z*jnp.pi))[None,:,:,:,None]\n",
|
||||
"\n",
|
||||
"print(\"in shapes:\", data.shape, kernel.shape)\n",
|
||||
"dn = lax.conv_dimension_numbers(data.shape, kernel.shape,\n",
|
||||
|
@ -121,10 +121,10 @@ id: Yud1Y3ss-x1K
|
||||
outputId: 3185fba5-1ad7-462f-96ba-7ed1b0c3d5a2
|
||||
---
|
||||
# 2D kernel - HWIO layout
|
||||
kernel = jnp.zeros((3, 3, 3, 3), dtype=np.float32)
|
||||
kernel = jnp.zeros((3, 3, 3, 3), dtype=jnp.float32)
|
||||
kernel += jnp.array([[1, 1, 0],
|
||||
[1, 0,-1],
|
||||
[0,-1,-1]])[:, :, np.newaxis, np.newaxis]
|
||||
[0,-1,-1]])[:, :, jnp.newaxis, jnp.newaxis]
|
||||
|
||||
print("Edge Conv kernel:")
|
||||
plt.imshow(kernel[:, :, 0, 0]);
|
||||
@ -375,7 +375,7 @@ id: jJ-jcAn3cig-
|
||||
outputId: 67c46ace-6adc-4c47-c1c7-1f185be5fd4b
|
||||
---
|
||||
# 1D kernel - WIO layout
|
||||
kernel = np.array([[[1, 0, -1], [-1, 0, 1]],
|
||||
kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],
|
||||
[[1, 1, 1], [-1, -1, -1]]],
|
||||
dtype=jnp.float32).transpose([2,1,0])
|
||||
# 1D data - NWC layout
|
||||
@ -417,16 +417,16 @@ outputId: c99ec88c-6d5c-4acd-c8d3-331f026f5631
|
||||
import matplotlib as mpl
|
||||
|
||||
# Random 3D kernel - HWDIO layout
|
||||
kernel = np.array([
|
||||
kernel = jnp.array([
|
||||
[[0, 0, 0], [0, 1, 0], [0, 0, 0]],
|
||||
[[0, -1, 0], [-1, 0, -1], [0, -1, 0]],
|
||||
[[0, 0, 0], [0, 1, 0], [0, 0, 0]]],
|
||||
dtype=jnp.float32)[:, :, :, np.newaxis, np.newaxis]
|
||||
dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]
|
||||
|
||||
# 3D data - NHWDC layout
|
||||
data = np.zeros((1, 30, 30, 30, 1), dtype=jnp.float32)
|
||||
data = jnp.zeros((1, 30, 30, 30, 1), dtype=jnp.float32)
|
||||
x, y, z = np.mgrid[0:1:30j, 0:1:30j, 0:1:30j]
|
||||
data += (np.sin(2*x*jnp.pi)*np.cos(2*y*jnp.pi)*np.cos(2*z*jnp.pi))[None,:,:,:,None]
|
||||
data += (jnp.sin(2*x*jnp.pi)*jnp.cos(2*y*jnp.pi)*jnp.cos(2*z*jnp.pi))[None,:,:,:,None]
|
||||
|
||||
print("in shapes:", data.shape, kernel.shape)
|
||||
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
|
||||
|
Loading…
x
Reference in New Issue
Block a user