mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Refer to the original map
/zip
classes via builtins
Referring to them as simply `map` or `zip` will create recursive reimplementations (with no base case!) if the cell is reevaluated in the same runtime.
This commit is contained in:
parent
d9e7a2abf8
commit
921fd222bf
@ -570,6 +570,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import builtins\n",
|
||||
"\n",
|
||||
"def zeros_like(val):\n",
|
||||
" aval = get_aval(val)\n",
|
||||
" return np.zeros(aval.shape, aval.dtype)\n",
|
||||
@ -581,17 +583,15 @@
|
||||
" lst2.append(x2)\n",
|
||||
" return lst1, lst2\n",
|
||||
"\n",
|
||||
"map_ = map\n",
|
||||
"def map(f, *xs):\n",
|
||||
" return list(map_(f, *xs))\n",
|
||||
" return list(builtins.map(f, *xs))\n",
|
||||
"\n",
|
||||
"zip_ = zip\n",
|
||||
"def zip(*args):\n",
|
||||
" fst, *rest = args = map(list, args)\n",
|
||||
" n = len(fst)\n",
|
||||
" for arg in rest:\n",
|
||||
" assert len(arg) == n\n",
|
||||
" return list(zip_(*args))"
|
||||
" return list(builtins.zip(*args))"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -460,6 +460,8 @@ that now we can add some real transformations.
|
||||
First, a few helper functions:
|
||||
|
||||
```{code-cell}
|
||||
import builtins
|
||||
|
||||
def zeros_like(val):
|
||||
aval = get_aval(val)
|
||||
return np.zeros(aval.shape, aval.dtype)
|
||||
@ -471,17 +473,15 @@ def unzip2(pairs):
|
||||
lst2.append(x2)
|
||||
return lst1, lst2
|
||||
|
||||
map_ = map
|
||||
def map(f, *xs):
|
||||
return list(map_(f, *xs))
|
||||
return list(builtins.map(f, *xs))
|
||||
|
||||
zip_ = zip
|
||||
def zip(*args):
|
||||
fst, *rest = args = map(list, args)
|
||||
n = len(fst)
|
||||
for arg in rest:
|
||||
assert len(arg) == n
|
||||
return list(zip_(*args))
|
||||
return list(builtins.zip(*args))
|
||||
```
|
||||
|
||||
The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The
|
||||
|
@ -444,8 +444,6 @@ def f(x):
|
||||
return z
|
||||
|
||||
print(f(3.0))
|
||||
|
||||
|
||||
# -
|
||||
|
||||
# Woo! Like going around in a big circle. But the point of this indirection is
|
||||
@ -456,6 +454,8 @@ print(f(3.0))
|
||||
# First, a few helper functions:
|
||||
|
||||
# +
|
||||
import builtins
|
||||
|
||||
def zeros_like(val):
|
||||
aval = get_aval(val)
|
||||
return np.zeros(aval.shape, aval.dtype)
|
||||
@ -467,17 +467,15 @@ def unzip2(pairs):
|
||||
lst2.append(x2)
|
||||
return lst1, lst2
|
||||
|
||||
map_ = map
|
||||
def map(f, *xs):
|
||||
return list(map_(f, *xs))
|
||||
return list(builtins.map(f, *xs))
|
||||
|
||||
zip_ = zip
|
||||
def zip(*args):
|
||||
fst, *rest = args = map(list, args)
|
||||
n = len(fst)
|
||||
for arg in rest:
|
||||
assert len(arg) == n
|
||||
return list(zip_(*args))
|
||||
return list(builtins.zip(*args))
|
||||
|
||||
|
||||
# -
|
||||
|
Loading…
x
Reference in New Issue
Block a user