Refer to the original map/zip classes via builtins

Referring to them as simply `map` or `zip` will create recursive
reimplementations (with no base case!) if the cell is reevaluated in
the same runtime.
This commit is contained in:
Kevin Millikin 2023-05-24 07:47:50 +01:00
parent d9e7a2abf8
commit 921fd222bf
3 changed files with 12 additions and 14 deletions

View File

@ -570,6 +570,8 @@
"metadata": {},
"outputs": [],
"source": [
"import builtins\n",
"\n",
"def zeros_like(val):\n",
" aval = get_aval(val)\n",
" return np.zeros(aval.shape, aval.dtype)\n",
@ -581,17 +583,15 @@
" lst2.append(x2)\n",
" return lst1, lst2\n",
"\n",
"map_ = map\n",
"def map(f, *xs):\n",
" return list(map_(f, *xs))\n",
" return list(builtins.map(f, *xs))\n",
"\n",
"zip_ = zip\n",
"def zip(*args):\n",
" fst, *rest = args = map(list, args)\n",
" n = len(fst)\n",
" for arg in rest:\n",
" assert len(arg) == n\n",
" return list(zip_(*args))"
" return list(builtins.zip(*args))"
]
},
{

View File

@ -460,6 +460,8 @@ that now we can add some real transformations.
First, a few helper functions:
```{code-cell}
import builtins
def zeros_like(val):
aval = get_aval(val)
return np.zeros(aval.shape, aval.dtype)
@ -471,17 +473,15 @@ def unzip2(pairs):
lst2.append(x2)
return lst1, lst2
map_ = map
def map(f, *xs):
return list(map_(f, *xs))
return list(builtins.map(f, *xs))
zip_ = zip
def zip(*args):
fst, *rest = args = map(list, args)
n = len(fst)
for arg in rest:
assert len(arg) == n
return list(zip_(*args))
return list(builtins.zip(*args))
```
The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The

View File

@ -444,8 +444,6 @@ def f(x):
return z
print(f(3.0))
# -
# Woo! Like going around in a big circle. But the point of this indirection is
@ -456,6 +454,8 @@ print(f(3.0))
# First, a few helper functions:
# +
import builtins
def zeros_like(val):
aval = get_aval(val)
return np.zeros(aval.shape, aval.dtype)
@ -467,17 +467,15 @@ def unzip2(pairs):
lst2.append(x2)
return lst1, lst2
map_ = map
def map(f, *xs):
return list(map_(f, *xs))
return list(builtins.map(f, *xs))
zip_ = zip
def zip(*args):
fst, *rest = args = map(list, args)
n = len(fst)
for arg in rest:
assert len(arg) == n
return list(zip_(*args))
return list(builtins.zip(*args))
# -