mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
DOC: add section to Sharp Bits discussing implicit list conversions
This commit is contained in:
parent
62603fde67
commit
d844609c6d
@@ -568,6 +568,209 @@
 "Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)."
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "LwB07Kx5sgHu"
+},
+"source": [
+"## 🔪 Non-array inputs: NumPy vs. JAX\n",
+"\n",
+"NumPy is generally happy accepting Python lists or tuples as inputs to its API functions:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 16,
+"metadata": {
+"id": "sErQES14sjCG",
+"outputId": "6bc29168-624a-4d51-eef1-220aeaf49985"
+},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"6"
+]
+},
+"execution_count": 16,
+"metadata": {
+"tags": []
+},
+"output_type": "execute_result"
+}
+],
+"source": [
+"np.sum([1, 2, 3])"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "ZJ1Wt1bTtrSA"
+},
+"source": [
+"JAX departs from this, generally returning a helpful error:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 39,
+"metadata": {
+"id": "DFEGcENSsmEc",
+"outputId": "86105261-0aec-41e0-c8a6-16eec437e2a8"
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.\n"
+]
+}
+],
+"source": [
+"try:\n",
+"  jnp.sum([1, 2, 3])\n",
+"except TypeError as e:\n",
+"  print(f\"TypeError: {e}\")"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "QPliLUZztxJt"
+},
+"source": [
+"This is a deliberate design choice, because passing lists or tuples to traced functions can lead to silent performance degradation that might otherwise be difficult to detect.\n",
+"\n",
+"For example, consider the following permissive version of `jnp.sum` that allows list inputs:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 32,
+"metadata": {
+"id": "jhe-L_TwsvKd",
+"outputId": "24ef84d4-79e5-42de-f8d4-34e6701c2576"
+},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"DeviceArray(45, dtype=int32)"
+]
+},
+"execution_count": 32,
+"metadata": {
+"tags": []
+},
+"output_type": "execute_result"
+}
+],
+"source": [
+"def permissive_sum(x):\n",
+"  return jnp.sum(jnp.array(x))\n",
+"\n",
+"x = list(range(10))\n",
+"permissive_sum(x)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "m0XZLP7nuYdE"
+},
+"source": [
+"The output is what we would expect, but this hides potential performance issues under the hood. In JAX's tracing and JIT compilation model, each element in a Python list or tuple is treated as a separate JAX variable, and individually processed and pushed to device. This can be seen in the jaxpr for the ``permissive_sum`` function above:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 31,
+"metadata": {
+"id": "k81u6DQ7vAjQ",
+"outputId": "52847378-ba8c-4e84-fb8b-dabbaded6a00"
+},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"{ lambda ; a b c d e f g h i j.\n",
+"  let k = broadcast_in_dim[ broadcast_dimensions=( )\n",
+"          shape=(1,) ] a\n",
+"      l = broadcast_in_dim[ broadcast_dimensions=( )\n",
+"          shape=(1,) ] b\n",
+"      m = broadcast_in_dim[ broadcast_dimensions=( )\n",
+"          shape=(1,) ] c\n",
+"      n = broadcast_in_dim[ broadcast_dimensions=( )\n",
+"          shape=(1,) ] d\n",
+"      o = broadcast_in_dim[ broadcast_dimensions=( )\n",
+"          shape=(1,) ] e\n",
+"      p = broadcast_in_dim[ broadcast_dimensions=( )\n",
+"          shape=(1,) ] f\n",
+"      q = broadcast_in_dim[ broadcast_dimensions=( )\n",
+"          shape=(1,) ] g\n",
+"      r = broadcast_in_dim[ broadcast_dimensions=( )\n",
+"          shape=(1,) ] h\n",
+"      s = broadcast_in_dim[ broadcast_dimensions=( )\n",
+"          shape=(1,) ] i\n",
+"      t = broadcast_in_dim[ broadcast_dimensions=( )\n",
+"          shape=(1,) ] j\n",
+"      u = concatenate[ dimension=0 ] k l m n o p q r s t\n",
+"      v = convert_element_type[ new_dtype=int32\n",
+"          weak_type=False ] u\n",
+"      w = reduce_sum[ axes=(0,) ] v\n",
+"  in (w,) }"
+]
+},
+"execution_count": 31,
+"metadata": {
+"tags": []
+},
+"output_type": "execute_result"
+}
+],
+"source": [
+"make_jaxpr(permissive_sum)(x)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "C0_dpCfpvCts"
+},
+"source": [
+"Each entry of the list is handled as a separate input, resulting in a tracing & compilation overhead that grows linearly with the size of the list. To prevent surprises like this, JAX avoids implicit conversions of lists and tuples to arrays.\n",
+"\n",
+"If you would like to pass a tuple or list to a JAX function, you can do so by first explicitly converting it to an array:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 33,
+"metadata": {
+"id": "nFf_DydixG8v",
+"outputId": "5e4392b6-37eb-4a24-ce4f-43518e61d9b1"
+},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"DeviceArray(45, dtype=int32)"
+]
+},
+"execution_count": 33,
+"metadata": {
+"tags": []
+},
+"output_type": "execute_result"
+}
+],
+"source": [
+"jnp.sum(jnp.array(x))"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {
@@ -55,10 +55,9 @@ JAX transformation and compilation are designed to work only on Python functions
 Here are some examples of functions that are not functionally pure for which JAX behaves differently than the Python interpreter. Note that these behaviors are not guaranteed by the JAX system; the proper way to use JAX is to use it only on functionally pure Python functions.
 
 ```{code-cell} ipython3
----
-id: A6R-pdcm4u3v
-outputId: 389605df-a4d5-4d4b-8d74-64e9d5d39456
----
+:id: A6R-pdcm4u3v
+:outputId: 389605df-a4d5-4d4b-8d74-64e9d5d39456
+
 def impure_print_side_effect(x):
   print("Executing function")  # This is a side-effect
   return x
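The gist of this pitfall can be modeled without JAX at all. Below is a toy sketch (our own memoizing decorator, not JAX's actual tracing machinery) in which, like `jit`, the Python body is executed only when a new input signature is seen, so the side effect fires on the first call and on re-traces, but not on repeated calls with the same signature:

```python
import functools

calls = {"python_body": 0}

def impure_fn(x):
    calls["python_body"] += 1   # stands in for print("Executing function")
    return x

def toy_jit(fun):
    """Toy model of jit: run the Python body once per input signature,
    then replay a recorded (here: trivial) computation on later calls."""
    cache = {}
    @functools.wraps(fun)
    def wrapped(x):
        sig = type(x)                 # stand-in for JAX's shape/dtype signature
        if sig not in cache:
            fun(x)                    # "tracing": body and its side effects run here
            cache[sig] = lambda v: v  # the "compiled", side-effect-free computation
        return cache[sig](x)
    return wrapped

f = toy_jit(impure_fn)
f(4); f(5)      # same signature: the Python body runs only once
f([4.0])        # new signature triggers a re-trace: body runs again
```

The same cache-keyed-on-signature behavior is why the real `jit` prints on the first call and on calls with a new shape or dtype, but stays silent otherwise.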
@@ -75,10 +74,9 @@ print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([
 ```
 
 ```{code-cell} ipython3
----
-id: -N8GhitI2bhD
-outputId: f16ce914-1387-43b4-9b8a-1d6e3b97b11d
----
+:id: -N8GhitI2bhD
+:outputId: f16ce914-1387-43b4-9b8a-1d6e3b97b11d
+
 g = 0.
 def impure_uses_globals(x):
   return x + g
@@ -96,10 +94,9 @@ print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.]))
 ```
 
 ```{code-cell} ipython3
----
-id: RTB6iFgu4DL6
-outputId: e93d2a70-1c18-477a-d69d-d09ed556305a
----
+:id: RTB6iFgu4DL6
+:outputId: e93d2a70-1c18-477a-d69d-d09ed556305a
+
 g = 0.
 def impure_saves_global(x):
   global g
@@ -116,10 +113,9 @@ print ("Saved global: ", g)  # Saved global has an internal JAX value
 A Python function can be functionally pure even if it actually uses stateful objects internally, as long as it does not read or write external state:
 
 ```{code-cell} ipython3
----
-id: TP-Mqf_862C0
-outputId: 78df2d95-2c6f-41c9-84a9-feda6329e75e
----
+:id: TP-Mqf_862C0
+:outputId: 78df2d95-2c6f-41c9-84a9-feda6329e75e
+
 def pure_uses_internal_state(x):
   state = dict(even=0, odd=0)
   for i in range(10):
@@ -168,10 +164,9 @@ iter_operand = iter(range(10))
 In Numpy you're used to doing this:
 
 ```{code-cell} ipython3
----
-id: om4xV7_84N9j
-outputId: 733f901e-d433-4dc8-b5bb-0c23bf2b1306
----
+:id: om4xV7_84N9j
+:outputId: 733f901e-d433-4dc8-b5bb-0c23bf2b1306
+
 numpy_array = np.zeros((3,3), dtype=np.float32)
 print("original array:")
 print(numpy_array)
@@ -187,11 +182,10 @@ print(numpy_array)
 If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉)
 
 ```{code-cell} ipython3
----
-id: 2AxeCufq4wAp
-outputId: d5d873db-cee0-49dc-981d-ec852347f7ca
-tags: [raises-exception]
----
+:id: 2AxeCufq4wAp
+:outputId: d5d873db-cee0-49dc-981d-ec852347f7ca
+:tags: [raises-exception]
+
 jax_array = jnp.zeros((3,3), dtype=jnp.float32)
 
 # In place update of JAX's array will yield an error!
@@ -226,10 +220,9 @@ from jax.ops import index, index_add, index_update
 If the __input values__ of __index_update__ aren't reused, __jit__-compiled code will perform these operations _in-place_.
 
 ```{code-cell} ipython3
----
-id: ygUJT49b7BBk
-outputId: 1a3511c4-a480-472f-cccb-5e01620cbe99
----
+:id: ygUJT49b7BBk
+:outputId: 1a3511c4-a480-472f-cccb-5e01620cbe99
+
 jax_array = jnp.zeros((3, 3))
 print("original array:")
 print(jax_array)
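The out-of-place semantics of `index_update` described above can be sketched in plain NumPy (this is our own stand-in for illustration, not JAX's implementation): the input array is left untouched, and an updated copy is returned:

```python
import numpy as np

def index_update_sketch(x, idx, y):
    """Return a copy of x with x[idx] set to y; the input array is not modified."""
    out = np.array(x, copy=True)
    out[idx] = y
    return out

a = np.zeros((3, 3))
b = index_update_sketch(a, (1, 2), 7.0)
# a is unchanged; b differs from a only at position (1, 2)
```

The copy here is what makes the operation functionally pure; JAX's compiler is then free to elide the copy and update in place when the original value is never reused.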
@@ -252,10 +245,9 @@ print(new_jax_array)
 If the __input values__ of __index_update__ aren't reused, __jit__-compiled code will perform these operations _in-place_.
 
 ```{code-cell} ipython3
----
-id: tsw2svao8FUp
-outputId: 874acd15-a493-4d63-efe4-9f440d5d2a12
----
+:id: tsw2svao8FUp
+:outputId: 874acd15-a493-4d63-efe4-9f440d5d2a12
+
 print("original array:")
 jax_array = jnp.ones((5, 6))
 print(jax_array)
@@ -274,11 +266,10 @@ print(new_jax_array)
 In Numpy, you are used to errors being thrown when you index an array outside of its bounds, like this:
 
 ```{code-cell} ipython3
----
-id: 5_ZM-BJUypdO
-outputId: 461f38cd-9452-4bcc-a44f-a07ddfa12f42
-tags: [raises-exception]
----
+:id: 5_ZM-BJUypdO
+:outputId: 461f38cd-9452-4bcc-a44f-a07ddfa12f42
+:tags: [raises-exception]
+
 try:
   np.arange(10)[11]
 except Exception as e:
@@ -290,10 +281,9 @@ except Exception as e:
 However, raising an error from code running on an accelerator can be difficult or impossible. Therefore, JAX must choose some non-error behavior for out of bounds indexing (akin to how invalid floating point arithmetic results in `NaN`). When the indexing operation is an array index update (e.g. `index_add` or `scatter`-like primitives), updates at out-of-bounds indices will be skipped; when the operation is an array index retrieval (e.g. NumPy indexing or `gather`-like primitives) the index is clamped to the bounds of the array since __something__ must be returned. For example, the last value of the array will be returned from this indexing operation:
 
 ```{code-cell} ipython3
----
-id: cusaAD0NypdR
-outputId: 48428ad6-6cde-43ad-c12d-2eb9b9fe59cf
----
+:id: cusaAD0NypdR
+:outputId: 48428ad6-6cde-43ad-c12d-2eb9b9fe59cf
+
 jnp.arange(10)[11]
 ```
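The two out-of-bounds rules stated above — skip on update, clamp on retrieval — can be written out in plain Python. This is a sketch of the documented semantics, not JAX's actual gather/scatter lowering:

```python
def oob_get(xs, i):
    """Index retrieval: clamp the index into bounds, as JAX's gather does."""
    return xs[max(0, min(i, len(xs) - 1))]

def oob_add(xs, i, y):
    """Index update: silently skip out-of-bounds indices, as JAX's scatter-add does."""
    out = list(xs)
    if 0 <= i < len(out):
        out[i] += y
    return out

xs = list(range(10))
oob_get(xs, 11)        # index clamped to 9, so the last value (9) is returned
oob_add(xs, 11, 100)   # out-of-bounds update is dropped; result equals xs
```

Writing the two rules side by side also makes the non-invertibility visible: clamping maps index 11 onto 9, while skipping discards it entirely, which is exactly why reverse-mode differentiation cannot preserve out-of-bounds semantics.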
 
@@ -301,6 +291,79 @@ Note that due to this behavior for index retrieval, functions like `jnp.nanargmi
 
 Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior).
 
++++ {"id": "LwB07Kx5sgHu"}
+
+## 🔪 Non-array inputs: NumPy vs. JAX
+
+NumPy is generally happy accepting Python lists or tuples as inputs to its API functions:
+
+```{code-cell} ipython3
+---
+id: sErQES14sjCG
+outputId: 6bc29168-624a-4d51-eef1-220aeaf49985
+---
+np.sum([1, 2, 3])
+```
+
++++ {"id": "ZJ1Wt1bTtrSA"}
+
+JAX departs from this, generally returning a helpful error:
+
+```{code-cell} ipython3
+---
+id: DFEGcENSsmEc
+outputId: 86105261-0aec-41e0-c8a6-16eec437e2a8
+---
+try:
+  jnp.sum([1, 2, 3])
+except TypeError as e:
+  print(f"TypeError: {e}")
+```
+
++++ {"id": "QPliLUZztxJt"}
+
+This is a deliberate design choice, because passing lists or tuples to traced functions can lead to silent performance degradation that might otherwise be difficult to detect.
+
+For example, consider the following permissive version of `jnp.sum` that allows list inputs:
+
+```{code-cell} ipython3
+---
+id: jhe-L_TwsvKd
+outputId: 24ef84d4-79e5-42de-f8d4-34e6701c2576
+---
+def permissive_sum(x):
+  return jnp.sum(jnp.array(x))
+
+x = list(range(10))
+permissive_sum(x)
+```
+
++++ {"id": "m0XZLP7nuYdE"}
+
+The output is what we would expect, but this hides potential performance issues under the hood. In JAX's tracing and JIT compilation model, each element in a Python list or tuple is treated as a separate JAX variable, and individually processed and pushed to device. This can be seen in the jaxpr for the ``permissive_sum`` function above:
+
+```{code-cell} ipython3
+---
+id: k81u6DQ7vAjQ
+outputId: 52847378-ba8c-4e84-fb8b-dabbaded6a00
+---
+make_jaxpr(permissive_sum)(x)
+```
+
++++ {"id": "C0_dpCfpvCts"}
+
+Each entry of the list is handled as a separate input, resulting in a tracing & compilation overhead that grows linearly with the size of the list. To prevent surprises like this, JAX avoids implicit conversions of lists and tuples to arrays.
+
+If you would like to pass a tuple or list to a JAX function, you can do so by first explicitly converting it to an array:
+
+```{code-cell} ipython3
+---
+id: nFf_DydixG8v
+outputId: 5e4392b6-37eb-4a24-ce4f-43518e61d9b1
+---
+jnp.sum(jnp.array(x))
+```
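The linear growth described above can be made concrete without running JAX: a toy flattening function (our own sketch, not JAX's actual pytree code) that counts how many separate traced inputs an argument would contribute. A list of n scalars contributes n inputs, while a single array contributes just one:

```python
def count_traced_inputs(arg):
    """Toy pytree flattening: each list/tuple element becomes its own traced input."""
    if isinstance(arg, (list, tuple)):
        return sum(count_traced_inputs(a) for a in arg)
    return 1  # a scalar or a single array counts as one input

count_traced_inputs(list(range(10)))   # 10 separate inputs to trace and transfer
count_traced_inputs(object())          # one input, no matter how large the array is
```

This is the pattern visible in the jaxpr above: ten arguments `a` through `j`, each broadcast and concatenated, rather than a single array argument.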
+
 +++ {"id": "MUycRNh6e50W"}
 
 ## 🔪 Random Numbers
 
@@ -317,10 +380,9 @@ Note also that, as the two behaviors described above are not inverses of each ot
 You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:
 
 ```{code-cell} ipython3
----
-id: rr9FeP41fynt
-outputId: 849d84cf-04ad-4e8b-9505-a92f6c0d7a39
----
+:id: rr9FeP41fynt
+:outputId: 849d84cf-04ad-4e8b-9505-a92f6c0d7a39
+
 print(np.random.random())
 print(np.random.random())
 print(np.random.random())
@@ -387,10 +449,9 @@ JAX instead implements an _explicit_ PRNG where entropy production and consumpti
 The random state is described by two unsigned-int32s that we call a __key__:
 
 ```{code-cell} ipython3
----
-id: yPHE7KTWgAWs
-outputId: 329e7757-2461-434c-a08c-fde80a2d10c9
----
+:id: yPHE7KTWgAWs
+:outputId: 329e7757-2461-434c-a08c-fde80a2d10c9
+
 from jax import random
 key = random.PRNGKey(0)
 key
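The key-based design can be modeled with any strong hash function. Here is a sketch of a stateless, splittable PRNG using `hashlib` as a stand-in for JAX's counter-based (Threefry) construction — our own illustration of the idea, not JAX's implementation:

```python
import hashlib

def split(key, num=2):
    """Derive `num` independent subkeys from `key`; nothing is mutated."""
    return tuple(hashlib.sha256(key + bytes([i])).digest() for i in range(num))

def uniform(key):
    """A deterministic pseudorandom float in [0, 1) depending only on `key`."""
    return int.from_bytes(hashlib.sha256(key).digest()[:8], "big") / 2**64

key = b"seed-0"
key, subkey = split(key)
x = uniform(subkey)        # the same subkey always yields the same number
assert uniform(subkey) == x
```

Because each number is a pure function of its key, results are reproducible and parallelizable regardless of evaluation order, which is exactly the property JAX's design is after.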
@@ -403,10 +464,9 @@ JAX's random functions produce pseudorandom numbers from the PRNG state, but __d
 Reusing the same state will cause __sadness__ and __monotony__, depriving the enduser of __lifegiving chaos__:
 
 ```{code-cell} ipython3
----
-id: 7zUdQMynoE5e
-outputId: 50617324-b887-42f2-a7ff-2a10f92d876a
----
+:id: 7zUdQMynoE5e
+:outputId: 50617324-b887-42f2-a7ff-2a10f92d876a
+
 print(random.normal(key, shape=(1,)))
 print(key)
 # No no no!
@@ -419,10 +479,9 @@ print(key)
 Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:
 
 ```{code-cell} ipython3
----
-id: ASj0_rSzqgGh
-outputId: bcc2ed60-2e41-4ef8-e84f-c724654aa198
----
+:id: ASj0_rSzqgGh
+:outputId: bcc2ed60-2e41-4ef8-e84f-c724654aa198
+
 print("old key", key)
 key, subkey = random.split(key)
 normal_pseudorandom = random.normal(subkey, shape=(1,))
@@ -435,10 +494,9 @@ print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
 We propagate the __key__ and make new __subkeys__ whenever we need a new random number:
 
 ```{code-cell} ipython3
----
-id: jbC34XLor2Ek
-outputId: 6834a812-7160-4646-ee19-a246f683905a
----
+:id: jbC34XLor2Ek
+:outputId: 6834a812-7160-4646-ee19-a246f683905a
+
 print("old key", key)
 key, subkey = random.split(key)
 normal_pseudorandom = random.normal(subkey, shape=(1,))
@@ -451,10 +509,9 @@ print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
 We can generate more than one __subkey__ at a time:
 
 ```{code-cell} ipython3
----
-id: lEi08PJ4tfkX
-outputId: 3bb513de-8d14-4d37-ae57-51d6f5eaa762
----
+:id: lEi08PJ4tfkX
+:outputId: 3bb513de-8d14-4d37-ae57-51d6f5eaa762
+
 key, *subkeys = random.split(key, 4)
 for subkey in subkeys:
   print(random.normal(subkey, shape=(1,)))
@@ -471,10 +528,9 @@ for subkey in subkeys:
 If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager).
 
 ```{code-cell} ipython3
----
-id: aAx0T3F8lLtu
-outputId: 808cfa77-d924-4586-af19-35a8fd7d2238
----
+:id: aAx0T3F8lLtu
+:outputId: 808cfa77-d924-4586-af19-35a8fd7d2238
+
 def f(x):
   if x < 3:
     return 3. * x ** 2
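Because `grad` differentiates whichever Python branch actually executed, a numerical check makes the point in plain Python. The sketch below uses a central finite difference as a stand-in for `grad`, on a piecewise function in the spirit of the example above (the second branch here is our own choice for illustration):

```python
def f(x):
    # Piecewise function: the branch taken depends on the concrete value of x.
    if x < 3:
        return 3. * x ** 2
    else:
        return -4. * x

def numeric_grad(fun, x, eps=1e-6):
    """Central finite difference: a numerical stand-in for jax.grad,
    valid away from the branch point."""
    return (fun(x + eps) - fun(x - eps)) / (2 * eps)

numeric_grad(f, 2.0)   # derivative of 3x**2 at x=2, approximately 12
numeric_grad(f, 4.0)   # derivative of -4x, approximately -4
```

Each evaluation re-runs the Python `if`, so the derivative always matches the branch the concrete input selects — precisely what `grad` without `jit` does.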
@@ -494,10 +550,9 @@ Using control flow with `jit` is more complicated, and by default it has more co
 This works:
 
 ```{code-cell} ipython3
----
-id: OZ_BJX0CplNC
-outputId: 48ce004c-536a-44f5-b020-9267825e7e4d
----
+:id: OZ_BJX0CplNC
+:outputId: 48ce004c-536a-44f5-b020-9267825e7e4d
+
 @jit
 def f(x):
   for i in range(3):
@@ -512,10 +567,9 @@ print(f(3))
 So does this:
 
 ```{code-cell} ipython3
----
-id: pinVnmRWp6w6
-outputId: e3e6f2f7-ba59-4a98-cdfc-905c91b38ed1
----
+:id: pinVnmRWp6w6
+:outputId: e3e6f2f7-ba59-4a98-cdfc-905c91b38ed1
+
 @jit
 def g(x):
   y = 0.
@@ -531,10 +585,9 @@ print(g(jnp.array([1., 2., 3.])))
 But this doesn't, at least by default:
 
 ```{code-cell} ipython3
----
-id: 9z38AIKclRNM
-outputId: 466730dd-df8b-4b80-ac5e-e55b5ea85ec7
----
+:id: 9z38AIKclRNM
+:outputId: 466730dd-df8b-4b80-ac5e-e55b5ea85ec7
+
 @jit
 def f(x):
   if x < 3:
@@ -566,10 +619,9 @@ But there's a tradeoff here: if we trace a Python function on a `ShapedArray((),
 The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:
 
 ```{code-cell} ipython3
----
-id: -Tzp0H7Bt1Sn
-outputId: aba57a88-d8eb-40b0-ff22-7c266d892b13
----
+:id: -Tzp0H7Bt1Sn
+:outputId: aba57a88-d8eb-40b0-ff22-7c266d892b13
+
 def f(x):
   if x < 3:
     return 3. * x ** 2
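The `static_argnums` behavior amounts to caching compiled computations keyed on the concrete value of the static argument, re-tracing whenever a new value appears. A toy sketch of that caching policy (our own model, not JAX's actual compilation cache):

```python
import functools

trace_count = {"n": 0}

def toy_jit_static(fun):
    """Re-trace whenever the static argument takes a value not seen before."""
    cache = {}
    @functools.wraps(fun)
    def wrapped(x, static_arg):
        if static_arg not in cache:
            trace_count["n"] += 1   # a fresh trace/compile would happen here
            cache[static_arg] = fun # stand-in for the specialized computation
        return cache[static_arg](x, static_arg)
    return wrapped

@toy_jit_static
def scaled(x, n):
    return x * n

scaled(1.0, 3); scaled(2.0, 3)   # same static value: traced once
scaled(1.0, 4)                   # new static value: traced again
```

This is also the tradeoff the text describes: value-level specialization unlocks Python control flow on the static argument, at the cost of one compilation per distinct value.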
@@ -586,10 +638,9 @@ print(f(2.))
 Here's another example, this time involving a loop:
 
 ```{code-cell} ipython3
----
-id: iwY86_JKvD6b
-outputId: 1ec847ea-df2b-438d-c0a1-fabf7b93b73d
----
+:id: iwY86_JKvD6b
+:outputId: 1ec847ea-df2b-438d-c0a1-fabf7b93b73d
+
 def f(x, n):
   y = 0.
   for i in range(n):
@@ -612,10 +663,9 @@ In effect, the loop gets statically unrolled.  JAX can also trace at _higher_ le
 These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output happens to depend on the input variable `length`.
 
 ```{code-cell} ipython3
----
-id: Tqe9uLmUI_Gv
-outputId: fe319758-9959-434c-ab9d-0926e599dbc0
----
+:id: Tqe9uLmUI_Gv
+:outputId: fe319758-9959-434c-ab9d-0926e599dbc0
+
 def example_fun(length, val):
   return jnp.ones((length,)) * val
 # un-jit'd works fine
@@ -642,10 +692,9 @@ print(good_example_jit(5, 4))
 Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside __jit__'d functions:
 
 ```{code-cell} ipython3
----
-id: m2ABpRd8K094
-outputId: 64da37a0-aa06-46a3-e975-88c676c5b9fa
----
+:id: m2ABpRd8K094
+:outputId: 64da37a0-aa06-46a3-e975-88c676c5b9fa
+
 @jit
 def f(x):
   print(x)
@@ -680,10 +729,9 @@ def cond(pred, true_fun, false_fun, operand):
 ```
 
 ```{code-cell} ipython3
----
-id: SGxz9JOWeiyH
-outputId: b29da06c-037f-4b05-dbd8-ba52ac35a8cf
----
+:id: SGxz9JOWeiyH
+:outputId: b29da06c-037f-4b05-dbd8-ba52ac35a8cf
+
 from jax import lax
 
 operand = jnp.array([0.])
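The Python-semantics pseudocode for `cond` shown in the hunk header above can be run directly; here it is as a self-contained sketch of the documented equivalent (plain Python, not the traced `lax.cond`):

```python
def cond(pred, true_fun, false_fun, operand):
    # Python equivalent of lax.cond, per the semantics given in the docs:
    # exactly one branch function is applied to the operand.
    if pred:
        return true_fun(operand)
    else:
        return false_fun(operand)

cond(True, lambda x: x + 1, lambda x: x - 1, 0)    # → 1
cond(False, lambda x: x + 1, lambda x: x - 1, 0)   # → -1
```

The difference in the real `lax.cond` is that both branches are traced ahead of time and the predicate is resolved on the device, which is what lets the conditional live inside `jit`-compiled code.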
@@ -707,10 +755,9 @@ def while_loop(cond_fun, body_fun, init_val):
 ```
 
 ```{code-cell} ipython3
----
-id: jM-D39a-c436
-outputId: b9c97167-fecf-4559-9ca7-1cb0235d8ad2
----
+:id: jM-D39a-c436
+:outputId: b9c97167-fecf-4559-9ca7-1cb0235d8ad2
+
 init_val = 0
 cond_fun = lambda x: x<10
 body_fun = lambda x: x+1
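The `while_loop` pseudocode named in the hunk header above can likewise be executed as plain Python; a self-contained sketch using the same `cond_fun`/`body_fun` as the cell:

```python
def while_loop(cond_fun, body_fun, init_val):
    # Python equivalent of lax.while_loop, per the semantics given in the docs.
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

while_loop(lambda x: x < 10, lambda x: x + 1, 0)   # counts 0 up to 10 → 10
```

Note that, unlike this Python version, `lax.while_loop` requires the carried value to keep a fixed shape and dtype, and it is not reverse-mode differentiable because the trip count is unbounded.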
@@ -731,10 +778,9 @@ def fori_loop(start, stop, body_fun, init_val):
 ```
 
 ```{code-cell} ipython3
----
-id: dt3tUpOmeR8u
-outputId: 864f2959-2429-4666-b364-4baf90a57482
----
+:id: dt3tUpOmeR8u
+:outputId: 864f2959-2429-4666-b364-4baf90a57482
+
 init_val = 0
 start = 0
 stop = 10
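And the `fori_loop` pseudocode from the hunk header runs as plain Python too. A sketch with a summing body (the body function here is our own illustrative choice, matching the cell's `start = 0`, `stop = 10`):

```python
def fori_loop(start, stop, body_fun, init_val):
    # Python equivalent of lax.fori_loop, per the semantics given in the docs:
    # body_fun receives the loop index and the carried value.
    val = init_val
    for i in range(start, stop):
        val = body_fun(i, val)
    return val

fori_loop(0, 10, lambda i, val: val + i, 0)   # sum of 0..9 → 45
```

Because the trip count of `fori_loop` is known, it lowers to a bounded loop and (unlike `while_loop`) remains compatible with reverse-mode differentiation.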
@@ -906,10 +952,9 @@ When this code sees a nan in the output of an `@jit` function, it calls into the
 At the moment, JAX by default enforces single-precision numbers to mitigate the Numpy API's tendency to aggressively promote operands to `double`. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!
 
 ```{code-cell} ipython3
----
-id: CNNGtzM3NDkO
-outputId: d1384021-d9bf-450f-a9ae-82024fa5fc1a
----
+:id: CNNGtzM3NDkO
+:outputId: d1384021-d9bf-450f-a9ae-82024fa5fc1a
+
 x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
 x.dtype
 ```
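The precision gap itself can be seen in plain NumPy, which defaults to double precision; JAX's default mode would keep only the float32 portion of such a value. A NumPy-only sketch of what the down-cast costs:

```python
import numpy as np

x64 = np.float64(1) / 3    # NumPy's default: double precision
x32 = np.float32(x64)      # what JAX's default single-precision mode retains

# float32 keeps roughly 7 significant decimal digits vs float64's ~16,
# so the round-trip error is small but nonzero:
err = abs(float(x64) - float(x32))
```

This is why code ported from NumPy can silently lose precision under JAX's defaults unless `x64` mode is enabled, as described below.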
@@ -951,10 +996,9 @@ Note that #2-#4 work for _any_ of JAX's configuration options.
 We can then confirm that `x64` mode is enabled:
 
 ```{code-cell} ipython3
----
-id: HqGbBa9Rr-2g
-outputId: cd241d63-3d00-4fd7-f9c0-afc6af01ecf4
----
+:id: HqGbBa9Rr-2g
+:outputId: cd241d63-3d00-4fd7-f9c0-afc6af01ecf4
+
 import jax.numpy as jnp
 from jax import random
 x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
 x