Add jax_numpy_dtype_promotion='strict' mode

This commit is contained in:
Jake VanderPlas 2022-05-26 10:56:09 -07:00
parent 563a6337fa
commit ceae6fe5e2
4 changed files with 142 additions and 41 deletions

View File

@ -47,6 +47,7 @@ from jax._src.config import (
log_compiles as log_compiles,
default_matmul_precision as default_matmul_precision,
default_prng_impl as default_prng_impl,
numpy_dtype_promotion as numpy_dtype_promotion,
numpy_rank_promotion as numpy_rank_promotion,
jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions,
transfer_guard as transfer_guard,

View File

@ -350,7 +350,8 @@ class Config:
Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately."""
return (self.x64_enabled, self.jax_numpy_rank_promotion,
self.jax_default_matmul_precision, self.jax_dynamic_shapes)
self.jax_default_matmul_precision, self.jax_dynamic_shapes,
self.jax_numpy_dtype_promotion)
class NoDefault: pass
no_default = NoDefault()
@ -437,6 +438,7 @@ already_configured_with_absl = False
class GlobalJitState(NamedTuple):
numpy_rank_promotion: Optional[str] = None
numpy_dtype_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
dynamic_shapes: bool = False
@ -450,6 +452,7 @@ def update_global_jit_state(**kw):
class ThreadLocalJitState(NamedTuple):
dynamic_trace_state: Optional[Any] = None
numpy_rank_promotion: Optional[str] = None
numpy_dtype_promotion: Optional[str] = None
default_matmul_precision: Optional[Any] = None
dynamic_shapes: bool = False
@ -627,6 +630,19 @@ config.define_enum_state(
'This is a temporary flag that will be used during the process '
'of deprecating the ``jax_enable_x64`` flag.'))
numpy_dtype_promotion = config.define_enum_state(
name='jax_numpy_dtype_promotion',
enum_values=['standard', 'strict'],
default='standard',
help=('Specify the rules used for implicit type promotion in operations '
'between arrays. Options are "standard" or "strict"; in strict-mode, '
'binary operations between arrays of differing strongly-specified '
'dtypes will result in an error.'),
update_global_hook=lambda val: \
update_global_jit_state(numpy_dtype_promotion=val),
update_thread_local_hook=lambda val: \
update_thread_local_jit_state(numpy_dtype_promotion=val))
def _update_x64_global(val):
lib.jax_jit.global_state().enable_x64 = val

View File

@ -21,7 +21,7 @@
import functools
from typing import Any, Dict
from typing import Any, Dict, List
import numpy as np
@ -243,24 +243,29 @@ issubsctype = np.issubsctype
# Enumeration of all valid JAX types in order.
_weak_types = [int, float, complex]
_jax_types = [
np.dtype('bool'),
np.dtype('uint8'),
np.dtype('uint16'),
np.dtype('uint32'),
np.dtype('uint64'),
np.dtype('int8'),
np.dtype('int16'),
np.dtype('int32'),
np.dtype('int64'),
np.dtype(bfloat16),
np.dtype('float16'),
np.dtype('float32'),
np.dtype('float64'),
np.dtype('complex64'),
np.dtype('complex128'),
_bool_types: List[np.dtype] = [np.dtype(bool)]
_int_types: List[np.dtype] = [
np.dtype('uint8'),
np.dtype('uint16'),
np.dtype('uint32'),
np.dtype('uint64'),
np.dtype('int8'),
np.dtype('int16'),
np.dtype('int32'),
np.dtype('int64'),
]
_jax_dtype_set = set(_jax_types) | {float0}
_float_types: List[np.dtype] = [
np.dtype(bfloat16),
np.dtype('float16'),
np.dtype('float32'),
np.dtype('float64'),
]
_complex_types: List[np.dtype] = [
np.dtype('complex64'),
np.dtype('complex128'),
]
_jax_types = _bool_types + _int_types + _float_types + _complex_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}
def _jax_type(dtype, weak_type):
"""Return the jax type for a dtype and weak type."""
@ -270,23 +275,37 @@ def _dtype_and_weaktype(value):
"""Return a (dtype, weak_type) tuple for the given input."""
return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
def _type_promotion_lattice():
def _type_promotion_lattice(jax_numpy_dtype_promotion):
"""
Return the type promotion lattice in the form of a DAG.
This DAG maps each type to its immediately higher type on the lattice.
"""
b1, u1, u2, u4, u8, i1, i2, i4, i8, bf, f2, f4, f8, c4, c8 = _jax_types
b1, = _bool_types
u1, u2, u4, u8, i1, i2, i4, i8 = _int_types
bf, f2, f4, f8 = _float_types
c4, c8 = _complex_types
i_, f_, c_ = _weak_types
return {
b1: [i_],
u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
f_: [bf, f2, c_], bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
c_: [c4], c4: [c8], c8: [],
}
if jax_numpy_dtype_promotion == 'standard':
return {
b1: [i_],
u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
f_: [bf, f2, c_], bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
c_: [c4], c4: [c8], c8: [],
}
elif jax_numpy_dtype_promotion == 'strict':
return {
i_: [f_] + _int_types,
f_: [c_] + _float_types,
c_: _complex_types,
**{t: [] for t in _jax_types}
}
else:
raise ValueError(
f"Unexpected value of jax_numpy_dtype_promotion={jax_numpy_dtype_promotion!r}")
def _make_lattice_upper_bounds():
lattice = _type_promotion_lattice()
def _make_lattice_upper_bounds(jax_numpy_dtype_promotion):
lattice = _type_promotion_lattice(jax_numpy_dtype_promotion)
upper_bounds = {node: {node} for node in lattice}
for n in lattice:
while True:
@ -297,10 +316,17 @@ def _make_lattice_upper_bounds():
break
upper_bounds[n] |= new_upper_bounds
return upper_bounds
_lattice_upper_bounds = _make_lattice_upper_bounds()
_lattice_upper_bounds = {
'standard': _make_lattice_upper_bounds('standard'),
'strict': _make_lattice_upper_bounds('strict'),
}
class TypePromotionError(ValueError):
pass
@functools.lru_cache(512) # don't use util.memoize because there is no X64 dependence.
def _least_upper_bound(*nodes):
def _least_upper_bound(jax_numpy_dtype_promotion, *nodes):
"""Compute the least upper bound of a set of nodes.
Args:
@ -327,13 +353,23 @@ def _least_upper_bound(*nodes):
# ∀ c ∈ N: CUB(N) ⊆ UB(c)
# So if N ∩ CUB(N) is nonempty, it follows that LUB(N) = N ∩ CUB(N).
N = set(nodes)
UB = _lattice_upper_bounds
UB = _lattice_upper_bounds[jax_numpy_dtype_promotion]
CUB = set.intersection(*(UB[n] for n in N))
LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
if len(LUB) == 1:
return LUB.pop()
elif len(LUB) == 0:
# TODO(jakevdp): surface some error about jax_numpy_dtype_promotion flag.
raise TypePromotionError(
f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
"promotion path. Try explicitly casting inputs to the desired output type.")
else:
raise ValueError(f"{nodes} do not have a unique least upper bound.")
# If we get here, it means the lattice is ill-formed.
raise TypePromotionError(
f"Internal Type Promotion error: {nodes} do not have a unique least upper bound "
f"on the specified lattice; options are {LUB}. If you see this error, please "
"report it to the JAX maintainers."
)
def promote_types(a, b):
"""Returns the type to which a binary operation should cast its arguments.
@ -351,7 +387,7 @@ def promote_types(a, b):
# object identity, not object equality, due to the behavior of np.dtype.__eq__
a = a if any(a is t for t in _weak_types) else np.dtype(a)
b = b if any(b is t for t in _weak_types) else np.dtype(b)
return np.dtype(_least_upper_bound(a, b))
return np.dtype(_least_upper_bound(config.jax_numpy_dtype_promotion, a, b))
def is_weakly_typed(x):
try:
@ -388,11 +424,14 @@ def _lattice_result_type(*args):
# If all inputs are weakly typed, we compute the bound of the strongly-typed
# counterparts and apply the weak type at the end. This avoids returning the
# incorrect result with non-canonical weak types (e.g. weak int16).
if all(weak_types):
result_type = _least_upper_bound(*{_jax_type(dtype, False) for dtype in dtypes})
# TODO(jakevdp): explore removing this special case.
if all(weak_types) and config.jax_numpy_dtype_promotion != 'strict':
result_type = _least_upper_bound(config.jax_numpy_dtype_promotion,
*{_jax_type(dtype, False) for dtype in dtypes})
return dtype(result_type), True
else:
result_type = _least_upper_bound(*{_jax_type(d, w) for d, w in zip(dtypes, weak_types)})
result_type = _least_upper_bound(config.jax_numpy_dtype_promotion,
*{_jax_type(d, w) for d, w in zip(dtypes, weak_types)})
return dtype(result_type), any(result_type is t for t in _weak_types)
def result_type(*args, return_weak_type_flag=False):

View File

@ -29,7 +29,7 @@ from jax import numpy as jnp
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax.config import config
from jax._src.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
@ -136,7 +136,44 @@ class DtypesTest(jtu.JaxTestCase):
self.assertTrue(isinstance(z, jnp.ndarray), msg=(x, y, z))
self.assertEqual(z.dtype, dtypes.canonicalize_dtype(dtype), msg=(x, y, z))
def testPromoteDtypes(self):
@jax.numpy_dtype_promotion('strict')
def testPromoteDtypesStrict(self):
# Check that strong types have diagonal promotion table:
for t1 in all_dtypes:
for t2 in all_dtypes:
if t1 == t2:
self.assertEqual(t1, dtypes.promote_types(t1, t2))
else:
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, t2)
# Promotion between weak types matches numpy promotion
for t1 in [int, float, complex]:
for t2 in [int, float, complex]:
py_result = type(t1(0) + t2(0))
lattice_dtype, lattice_weak_type = dtypes._lattice_result_type(t1, t2)
self.assertTrue(lattice_weak_type)
self.assertEqual(lattice_dtype, np.dtype(py_result))
# Check that weak promotion only works if strong value is not cast:
for t1 in bool_dtypes:
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, int)
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, float)
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, complex)
for t1 in signed_dtypes + unsigned_dtypes:
self.assertEqual(dtypes.promote_types(t1, int), t1)
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, float)
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, complex)
for t1 in float_dtypes:
self.assertEqual(dtypes.promote_types(t1, int), t1)
self.assertEqual(dtypes.promote_types(t1, float), t1)
self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, complex)
for t1 in complex_dtypes:
self.assertEqual(dtypes.promote_types(t1, int), t1)
self.assertEqual(dtypes.promote_types(t1, float), t1)
self.assertEqual(dtypes.promote_types(t1, complex), t1)
@jax.numpy_dtype_promotion('standard')
def testPromoteDtypesStandard(self):
for t1 in all_dtypes:
self.assertEqual(t1, dtypes.promote_types(t1, t1))
@ -163,7 +200,15 @@ class DtypesTest(jtu.JaxTestCase):
np_float_dtypes + complex_dtypes]:
for t1, t2 in itertools.combinations(groups, 2):
self.assertEqual(np.promote_types(t1, t2),
dtypes.promote_types(t1, t2))
dtypes.promote_types(t1, t2))
# Promotion between weak types matches numpy promotion
for t1 in [int, float, complex]:
for t2 in [int, float, complex]:
py_result = type(t1(0) + t2(0))
lattice_dtype, lattice_weak_type = dtypes._lattice_result_type(t1, t2)
self.assertTrue(lattice_weak_type)
self.assertEqual(lattice_dtype, np.dtype(py_result))
@parameterized.parameters([jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64])
def testScalarInstantiation(self, scalar_type):