add 'open in colab' button, add numpy<1.18 compat

Co-authored-by: Edward Loper <edloper@google.com>
This commit is contained in:
Matthew Johnson 2021-04-09 18:10:20 -07:00
parent b4a8261b3e
commit 5f6bce4bfe
3 changed files with 21 additions and 10 deletions

View File

@ -23,13 +23,14 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {
"lines_to_next_cell": 0
"lines_to_next_cell": 2
},
"outputs": [],
"source": []
"source": [
"[![Open in\n",
"Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/autodidax.ipynb)"
]
},
{
"cell_type": "markdown",
@ -538,7 +539,9 @@
"impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]\n",
"\n",
"def broadcast_impl(x, *, shape, axes):\n",
" return [np.broadcast_to(np.expand_dims(x, axes), shape)]\n",
" for axis in sorted(axes):\n",
" x = np.expand_dims(x, axis)\n",
" return [np.broadcast_to(x, shape)]\n",
"impl_rules[broadcast_p] = broadcast_impl"
]
},

View File

@ -32,9 +32,10 @@ limitations under the License.
---
```
```{code-cell}
[![Open in
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/autodidax.ipynb)
```
+++
# Autodidax: JAX core from scratch
@ -423,7 +424,9 @@ impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]
def broadcast_impl(x, *, shape, axes):
return [np.broadcast_to(np.expand_dims(x, axes), shape)]
for axis in sorted(axes):
x = np.expand_dims(x, axis)
return [np.broadcast_to(x, shape)]
impl_rules[broadcast_p] = broadcast_impl
```

View File

@ -26,6 +26,9 @@
# name: python3
# ---
# [![Open in
# Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/master/docs/autodidax.ipynb)
# # Autodidax: JAX core from scratch
#
@ -404,7 +407,9 @@ impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]
def broadcast_impl(x, *, shape, axes):
return [np.broadcast_to(np.expand_dims(x, axes), shape)]
for axis in sorted(axes):
x = np.expand_dims(x, axis)
return [np.broadcast_to(x, shape)]
impl_rules[broadcast_p] = broadcast_impl
# -