mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
add 'open in colab' button, add numpy<1.18 compat
Co-authored-by: Edward Loper <edloper@google.com>
This commit is contained in:
parent
b4a8261b3e
commit
5f6bce4bfe
@ -23,13 +23,14 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 0
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"[](https://colab.research.google.com/github/google/jax/blob/master/docs/autodidax.ipynb)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@ -538,7 +539,9 @@
|
||||
"impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]\n",
|
||||
"\n",
|
||||
"def broadcast_impl(x, *, shape, axes):\n",
|
||||
" return [np.broadcast_to(np.expand_dims(x, axes), shape)]\n",
|
||||
" for axis in sorted(axes):\n",
|
||||
" x = np.expand_dims(x, axis)\n",
|
||||
" return [np.broadcast_to(x, shape)]\n",
|
||||
"impl_rules[broadcast_p] = broadcast_impl"
|
||||
]
|
||||
},
|
||||
|
@ -32,9 +32,10 @@ limitations under the License.
|
||||
---
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
[](https://colab.research.google.com/github/google/jax/blob/master/docs/autodidax.ipynb)
|
||||
|
||||
```
|
||||
+++
|
||||
|
||||
# Autodidax: JAX core from scratch
|
||||
|
||||
@ -423,7 +424,9 @@ impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
|
||||
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]
|
||||
|
||||
def broadcast_impl(x, *, shape, axes):
|
||||
return [np.broadcast_to(np.expand_dims(x, axes), shape)]
|
||||
for axis in sorted(axes):
|
||||
x = np.expand_dims(x, axis)
|
||||
return [np.broadcast_to(x, shape)]
|
||||
impl_rules[broadcast_p] = broadcast_impl
|
||||
```
|
||||
|
||||
|
@ -26,6 +26,9 @@
|
||||
# name: python3
|
||||
# ---
|
||||
|
||||
# [](https://colab.research.google.com/github/google/jax/blob/master/docs/autodidax.ipynb)
|
||||
|
||||
|
||||
# # Autodidax: JAX core from scratch
|
||||
#
|
||||
@ -404,7 +407,9 @@ impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
|
||||
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]
|
||||
|
||||
def broadcast_impl(x, *, shape, axes):
|
||||
return [np.broadcast_to(np.expand_dims(x, axes), shape)]
|
||||
for axis in sorted(axes):
|
||||
x = np.expand_dims(x, axis)
|
||||
return [np.broadcast_to(x, shape)]
|
||||
impl_rules[broadcast_p] = broadcast_impl
|
||||
# -
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user