Fix nomenclature in type promotion doc

This commit is contained in:
Jake VanderPlas 2022-02-15 13:26:06 -08:00
parent bf3c658114
commit e82f232ea9
2 changed files with 528 additions and 488 deletions

File diff suppressed because it is too large Load Diff

Before

Width:  |  Height:  |  Size: 23 KiB

After

Width:  |  Height:  |  Size: 22 KiB

View File

@ -16,18 +16,18 @@ JAX's type promotion behavior is determined via the following type promotion lat
lattice = {
'b1': ['i*'], 'u1': ['u2', 'i2'], 'u2': ['i4', 'u4'], 'u4': ['u8', 'i8'], 'u8': ['f*'],
'i*': ['u1', 'i1'], 'i1': ['i2'], 'i2': ['i4'], 'i4': ['i8'], 'i8': ['f*'],
'f*': ['c*', 'f2', 'bf'], 'bf': ['f4'], 'f2': ['f4'], 'f4': ['c4', 'f8'], 'f8': ['c8'],
'c*': ['c4'], 'c4': ['c8'], 'c8': [],
'f*': ['c*', 'f2', 'bf'], 'bf': ['f4'], 'f2': ['f4'], 'f4': ['c8', 'f8'], 'f8': ['c16'],
'c*': ['c8'], 'c8': ['c16'], 'c16': [],
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
'b1': [0, 0], 'u1': [2, 0], 'u2': [3, 0], 'u4': [4, 0], 'u8': [5, 0],
'i*': [1, 1], 'i1': [2, 2], 'i2': [3, 2], 'i4': [4, 2], 'i8': [5, 2],
'f*': [6, 1], 'bf': [7.5, 0.6], 'f2': [7.5, 1.4], 'f4': [9, 1], 'f8': [10, 1],
'c*': [7, 2], 'c4': [10, 2], 'c8': [11, 2],
'c*': [7, 2], 'c8': [10, 2], 'c16': [11, 2],
}
fig, ax = plt.subplots(figsize=(8, 2.5))
nx.draw(graph, with_labels=True, node_size=600, node_color='lightgray', pos=pos, ax=ax)
fig, ax = plt.subplots(figsize=(8, 2.6))
nx.draw(graph, with_labels=True, node_size=650, node_color='lightgray', pos=pos, ax=ax)
fig.savefig('type_lattice.svg', bbox_inches='tight')
where, for example:
@ -37,7 +37,7 @@ where, for example:
* ``u4`` means :code:`np.uint32`,
* ``bf`` means :code:`np.bfloat16`,
* ``f2`` means :code:`np.float16`,
* ``c8`` means :code:`np.complex128`,
* ``c8`` means :code:`np.complex64`,
* ``i*`` means Python :code:`int` or weakly-typed :code:`int`,
* ``f*`` means Python :code:`float` or weakly-typed :code:`float`, and
* ``c*`` means Python :code:`complex` or weakly-typed :code:`complex`.
@ -74,30 +74,31 @@ on this lattice, which generates the following binary promotion table:
</style>
<table id="types">
<tr><th></th><th>b1</th><th>u1</th><th>u2</th><th>u4</th><th>u8</th><th>i1</th><th>i2</th><th>i4</th><th>i8</th><th>bf</th><th>f2</th><th>f4</th><th>f8</th><th>c4</th><th>c8</th><th>i*</th><th>f*</th><th>c*</th></tr>
<tr><td>b1</td><td>b1</td><td>u1</td><td>u2</td><td>u4</td><td>u8</td><td>i1</td><td>i2</td><td>i4</td><td>i8</td><td class="d">bf</td><td>f2</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td><td>i8</td><td>f8</td><td>c8</td></tr>
<tr><td>u1</td><td>u1</td><td>u1</td><td>u2</td><td>u4</td><td>u8</td><td>i2</td><td>i2</td><td>i4</td><td>i8</td><td class="d">bf</td><td>f2</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td><td class="d">u1</td><td>f8</td><td>c8</td></tr>
<tr><td>u2</td><td>u2</td><td>u2</td><td>u2</td><td>u4</td><td>u8</td><td>i4</td><td>i4</td><td>i4</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td><td class="d">u2</td><td>f8</td><td>c8</td></tr>
<tr><td>u4</td><td>u4</td><td>u4</td><td>u4</td><td>u4</td><td>u8</td><td>i8</td><td>i8</td><td>i8</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c4</td><td>c8</td><td class="d">u4</td><td>f8</td><td>c8</td></tr>
<tr><td>u8</td><td>u8</td><td>u8</td><td>u8</td><td>u8</td><td>u8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c4</td><td>c8</td><td class="d">u8</td><td>f8</td><td>c8</td></tr>
<tr><td>i1</td><td>i1</td><td>i2</td><td>i4</td><td>i8</td><td>f8</td><td>i1</td><td>i2</td><td>i4</td><td>i8</td><td class="d">bf</td><td>f2</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td><td class="d">i1</td><td>f8</td><td>c8</td></tr>
<tr><td>i2</td><td>i2</td><td>i2</td><td>i4</td><td>i8</td><td>f8</td><td>i2</td><td>i2</td><td>i4</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td><td class="d">i2</td><td>f8</td><td>c8</td></tr>
<tr><td>i4</td><td>i4</td><td>i4</td><td>i4</td><td>i8</td><td>f8</td><td>i4</td><td>i4</td><td>i4</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c4</td><td>c8</td><td class="d">i4</td><td>f8</td><td>c8</td></tr>
<tr><td>i8</td><td>i8</td><td>i8</td><td>i8</td><td>i8</td><td>f8</td><td>i8</td><td>i8</td><td>i8</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c4</td><td>c8</td><td>i8</td><td>f8</td><td>c8</td></tr>
<tr><td>bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">f4</td><td class="d">f4</td><td class="d">f8</td><td class="d">c4</td><td class="d">c8</td><td class="d">bf</td><td class="d">bf</td><td class="d">c4</td></tr>
<tr><td>f2</td><td>f2</td><td>f2</td><td class="d">f2</td><td class="d">f2</td><td class="d">f2</td><td>f2</td><td class="d">f2</td><td class="d">f2</td><td class="d">f2</td><td class="d">f4</td><td>f2</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td><td class="d">f2</td><td class="d">f2</td><td class="d">c4</td></tr>
<tr><td>f4</td><td>f4</td><td>f4</td><td>f4</td><td class="d">f4</td><td class="d">f4</td><td>f4</td><td>f4</td><td class="d">f4</td><td class="d">f4</td><td class="d">f4</td><td>f4</td><td>f4</td><td>f8</td><td>c4</td><td>c8</td><td class="d">f4</td><td class="d">f4</td><td class="d">c4</td></tr>
<tr><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td class="d">f8</td><td>f8</td><td>f8</td><td>f8</td><td>c8</td><td>c8</td><td>f8</td><td>f8</td><td>c8</td></tr>
<tr><td>c4</td><td>c4</td><td>c4</td><td>c4</td><td class="d">c4</td><td class="d">c4</td><td>c4</td><td>c4</td><td class="d">c4</td><td class="d">c4</td><td class="d">c4</td><td>c4</td><td>c4</td><td>c8</td><td>c4</td><td>c8</td><td class="d">c4</td><td class="d">c4</td><td class="d">c4</td></tr>
<tr><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td class="d">c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td></tr>
<tr><td>i*</td><td>i8</td><td class="d">u1</td><td class="d">u2</td><td class="d">u4</td><td class="d">u8</td><td class="d">i1</td><td class="d">i2</td><td class="d">i4</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c4</td><td>c8</td><td>i8</td><td>f8</td><td>c8</td></tr>
<tr><td>f*</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c4</td><td>c8</td><td>f8</td><td>f8</td><td>c8</td></tr>
<tr><td>c*</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td class="d">c4</td><td class="d">c4</td><td class="d">c4</td><td>c8</td><td class="d">c4</td><td>c8</td><td>c8</td><td>c8</td><td>c8</td></tr>
<tr><th></th><th>b1</th><th>u1</th><th>u2</th><th>u4</th><th>u8</th><th>i1</th><th>i2</th><th>i4</th><th>i8</th><th>bf</th><th>f2</th><th>f4</th><th>f8</th><th>c8</th><th>c16</th><th>i*</th><th>f*</th><th>c*</th></tr>
<tr><td>b1</td><td>b1</td><td>u1</td><td>u2</td><td>u4</td><td>u8</td><td>i1</td><td>i2</td><td>i4</td><td>i8</td><td class="d">bf</td><td>f2</td><td>f4</td><td>f8</td><td>c8</td><td>c16</td><td>i*</td><td>f*</td><td>c*</td></tr>
<tr><td>u1</td><td>u1</td><td>u1</td><td>u2</td><td>u4</td><td>u8</td><td>i2</td><td>i2</td><td>i4</td><td>i8</td><td class="d">bf</td><td>f2</td><td>f4</td><td>f8</td><td>c8</td><td>c16</td><td class="d">u1</td><td>f*</td><td>c*</td></tr>
<tr><td>u2</td><td>u2</td><td>u2</td><td>u2</td><td>u4</td><td>u8</td><td>i4</td><td>i4</td><td>i4</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td>f4</td><td>f8</td><td>c8</td><td>c16</td><td class="d">u2</td><td>f*</td><td>c*</td></tr>
<tr><td>u4</td><td>u4</td><td>u4</td><td>u4</td><td>u4</td><td>u8</td><td>i8</td><td>i8</td><td>i8</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c8</td><td>c16</td><td class="d">u4</td><td>f*</td><td>c*</td></tr>
<tr><td>u8</td><td>u8</td><td>u8</td><td>u8</td><td>u8</td><td>u8</td><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c8</td><td>c16</td><td class="d">u8</td><td>f*</td><td>c*</td></tr>
<tr><td>i1</td><td>i1</td><td>i2</td><td>i4</td><td>i8</td><td>f*</td><td>i1</td><td>i2</td><td>i4</td><td>i8</td><td class="d">bf</td><td>f2</td><td>f4</td><td>f8</td><td>c8</td><td>c16</td><td class="d">i1</td><td>f*</td><td>c*</td></tr>
<tr><td>i2</td><td>i2</td><td>i2</td><td>i4</td><td>i8</td><td>f*</td><td>i2</td><td>i2</td><td>i4</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td>f4</td><td>f8</td><td>c8</td><td>c16</td><td class="d">i2</td><td>f*</td><td>c*</td></tr>
<tr><td>i4</td><td>i4</td><td>i4</td><td>i4</td><td>i8</td><td>f*</td><td>i4</td><td>i4</td><td>i4</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c8</td><td>c16</td><td class="d">i4</td><td>f*</td><td>c*</td></tr>
<tr><td>i8</td><td>i8</td><td>i8</td><td>i8</td><td>i8</td><td>f*</td><td>i8</td><td>i8</td><td>i8</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c8</td><td>c16</td><td>i8</td><td>f*</td><td>c*</td></tr>
<tr><td>bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">bf</td><td class="d">f4</td><td class="d">f4</td><td class="d">f8</td><td class="d">c8</td><td class="d">c16</td><td class="d">bf</td><td class="d">bf</td><td class="d">c8</td></tr>
<tr><td>f2</td><td>f2</td><td>f2</td><td class="d">f2</td><td class="d">f2</td><td class="d">f2</td><td>f2</td><td class="d">f2</td><td class="d">f2</td><td class="d">f2</td><td class="d">f4</td><td>f2</td><td>f4</td><td>f8</td><td>c8</td><td>c16</td><td class="d">f2</td><td class="d">f2</td><td class="d">c8</td></tr>
<tr><td>f4</td><td>f4</td><td>f4</td><td>f4</td><td class="d">f4</td><td class="d">f4</td><td>f4</td><td>f4</td><td class="d">f4</td><td class="d">f4</td><td class="d">f4</td><td>f4</td><td>f4</td><td>f8</td><td>c8</td><td>c16</td><td class="d">f4</td><td class="d">f4</td><td class="d">c8</td></tr>
<tr><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td>f8</td><td class="d">f8</td><td>f8</td><td>f8</td><td>f8</td><td>c16</td><td>c16</td><td>f8</td><td>f8</td><td>c16</td></tr>
<tr><td>c8</td><td>c8</td><td>c8</td><td>c8</td><td class="d">c8</td><td class="d">c8</td><td>c8</td><td>c8</td><td class="d">c8</td><td class="d">c8</td><td class="d">c8</td><td>c8</td><td>c8</td><td>c16</td><td>c8</td><td>c16</td><td class="d">c8</td><td class="d">c8</td><td class="d">c8</td></tr>
<tr><td>c16</td><td>c16</td><td>c16</td><td>c16</td><td>c16</td><td>c16</td><td>c16</td><td>c16</td><td>c16</td><td>c16</td><td class="d">c16</td><td>c16</td><td>c16</td><td>c16</td><td>c16</td><td>c16</td><td>c16</td><td>c16</td><td>c16</td></tr>
<tr><td>i*</td><td>i*</td><td class="d">u1</td><td class="d">u2</td><td class="d">u4</td><td class="d">u8</td><td class="d">i1</td><td class="d">i2</td><td class="d">i4</td><td>i8</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c8</td><td>c16</td><td>i*</td><td>f*</td><td>c*</td></tr>
<tr><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td>f*</td><td class="d">bf</td><td class="d">f2</td><td class="d">f4</td><td>f8</td><td class="d">c8</td><td>c16</td><td>f*</td><td>f*</td><td>c*</td></tr>
<tr><td>c*</td><td>c*</td><td>c*</td><td>c*</td><td>c*</td><td>c*</td><td>c*</td><td>c*</td><td>c*</td><td>c*</td><td class="d">c8</td><td class="d">c8</td><td class="d">c8</td><td>c16</td><td class="d">c8</td><td>c16</td><td>c*</td><td>c*</td><td>c*</td></tr>
</table><p>
.. The table above was generated by the following Python code.
import numpy as np
import jax.numpy as jnp
from jax._src import dtypes
types = [np.bool_, np.uint8, np.uint16, np.uint32, np.uint64,
np.int8, np.int16, np.int32, np.int64,
@ -105,14 +106,10 @@ on this lattice, which generates the following binary promotion table:
np.complex64, np.complex128, int, float, complex]
def name(d):
if d in {int, float, complex}:
return f"{d.__name__[0]}*"
d = np.dtype(d)
if d == np.dtype(jnp.bfloat16):
if d == jnp.bfloat16:
return "bf"
return "{}{}".format(
d.kind,
d.itemsize // 2 if np.issubdtype(d, np.complexfloating) else d.itemsize)
itemsize = "*" if d in {int, float, complex} else np.dtype(d).itemsize
return f"{np.dtype(d).kind}{itemsize}"
out = "<tr><th></th>"
for t in types:
@ -122,8 +119,10 @@ on this lattice, which generates the following binary promotion table:
for t1 in types:
out += "<tr><td>{}</td>".format(name(t1))
for t2 in types:
t = jnp.promote_types(t1, t2)
different = jnp.bfloat16 in (t1, t2) or t is not np.promote_types(t1, t2)
t, weak_type = dtypes._lattice_result_type(t1, t2)
if weak_type:
t = type(t.type(0).item())
different = jnp.bfloat16 in (t1, t2) or jnp.promote_types(t1, t2) is not np.promote_types(t1, t2)
out += "<td{}>{}</td>".format(" class=\"d\"" if different else "", name(t))
out += "</tr>\n"
@ -131,7 +130,7 @@ on this lattice, which generates the following binary promotion table:
Jax's type promotion rules differ from those of NumPy, as given by
:func:`numpy.promote_types`, in those cells highlighted with a green background
in the table above. There are three key differences:
in the table above. There are three key classes of differences:
* When promoting a weakly typed value against a typed JAX value of the same category,
JAX always prefers the precision of the JAX value. For example, ``jnp.int16(1) + 1``