Add jax_numpy_dtype_promotion='strict' mode
parent 563a6337fa
commit ceae6fe5e2
@@ -47,6 +47,7 @@ from jax._src.config import (
   log_compiles as log_compiles,
   default_matmul_precision as default_matmul_precision,
   default_prng_impl as default_prng_impl,
+  numpy_dtype_promotion as numpy_dtype_promotion,
   numpy_rank_promotion as numpy_rank_promotion,
   jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions,
   transfer_guard as transfer_guard,
@@ -350,7 +350,8 @@ class Config:
     Values included in this set should also most likely be included in
     the C++ JIT state, which is handled separately."""
     return (self.x64_enabled, self.jax_numpy_rank_promotion,
-            self.jax_default_matmul_precision, self.jax_dynamic_shapes)
+            self.jax_default_matmul_precision, self.jax_dynamic_shapes,
+            self.jax_numpy_dtype_promotion)
 
 class NoDefault: pass
 no_default = NoDefault()
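The two entries added to _trace_context above matter because that tuple feeds JIT cache keys; together with the GlobalJitState/ThreadLocalJitState fields added in the hunks below, they keep traces compiled under one promotion mode from being reused under the other. A hedged illustration of that intent (hypothetical example, not part of the commit; the assertion assumes the cache is keyed exactly as intended):

import jax
import jax.numpy as jnp

trace_count = []

@jax.jit
def f(x):
  trace_count.append(None)  # runs only while tracing, so it counts retraces
  return x + 1.0

f(jnp.float32(1))                      # traced once under the default 'standard' mode
with jax.numpy_dtype_promotion('strict'):
  f(jnp.float32(1))                    # same shapes/dtypes, but the flag is part of the
                                       # trace context, so a fresh trace is expected here
assert len(trace_count) == 2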
@@ -437,6 +438,7 @@ already_configured_with_absl = False
 
 class GlobalJitState(NamedTuple):
   numpy_rank_promotion: Optional[str] = None
+  numpy_dtype_promotion: Optional[str] = None
   default_matmul_precision: Optional[Any] = None
   dynamic_shapes: bool = False
 
@@ -450,6 +452,7 @@ def update_global_jit_state(**kw):
 class ThreadLocalJitState(NamedTuple):
   dynamic_trace_state: Optional[Any] = None
   numpy_rank_promotion: Optional[str] = None
+  numpy_dtype_promotion: Optional[str] = None
   default_matmul_precision: Optional[Any] = None
   dynamic_shapes: bool = False
 
@@ -627,6 +630,19 @@ config.define_enum_state(
           'This is a temporary flag that will be used during the process '
           'of deprecating the ``jax_enable_x64`` flag.'))
 
+numpy_dtype_promotion = config.define_enum_state(
+    name='jax_numpy_dtype_promotion',
+    enum_values=['standard', 'strict'],
+    default='standard',
+    help=('Specify the rules used for implicit type promotion in operations '
+          'between arrays. Options are "standard" or "strict"; in strict-mode, '
+          'binary operations between arrays of differing strongly-specified '
+          'dtypes will result in an error.'),
+    update_global_hook=lambda val: \
+      update_global_jit_state(numpy_dtype_promotion=val),
+    update_thread_local_hook=lambda val: \
+      update_thread_local_jit_state(numpy_dtype_promotion=val))
+
 def _update_x64_global(val):
   lib.jax_jit.global_state().enable_x64 = val
 
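For orientation, the new enum state can be driven like any other JAX configuration flag; a minimal usage sketch (illustrative only, not part of the commit; names and values are examples):

import jax
import jax.numpy as jnp

jax.config.update('jax_numpy_dtype_promotion', 'strict')  # process-wide setting
x = jnp.float32(1) + jnp.float32(2)      # identical strong dtypes: still allowed
# jnp.float32(1) + jnp.int32(2) would now raise a TypePromotionError.

with jax.numpy_dtype_promotion('standard'):  # scoped override via the context manager
  y = jnp.float32(1) + jnp.int32(2)          # standard rules: promotes to float32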
@@ -21,7 +21,7 @@
 
 
 import functools
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 import numpy as np
 
@@ -243,24 +243,29 @@ issubsctype = np.issubsctype
 
 # Enumeration of all valid JAX types in order.
 _weak_types = [int, float, complex]
-_jax_types = [
-  np.dtype('bool'),
-  np.dtype('uint8'),
-  np.dtype('uint16'),
-  np.dtype('uint32'),
-  np.dtype('uint64'),
-  np.dtype('int8'),
-  np.dtype('int16'),
-  np.dtype('int32'),
-  np.dtype('int64'),
-  np.dtype(bfloat16),
-  np.dtype('float16'),
-  np.dtype('float32'),
-  np.dtype('float64'),
-  np.dtype('complex64'),
-  np.dtype('complex128'),
-]
-_jax_dtype_set = set(_jax_types) | {float0}
+_bool_types: List[np.dtype] = [np.dtype(bool)]
+_int_types: List[np.dtype] = [
+  np.dtype('uint8'),
+  np.dtype('uint16'),
+  np.dtype('uint32'),
+  np.dtype('uint64'),
+  np.dtype('int8'),
+  np.dtype('int16'),
+  np.dtype('int32'),
+  np.dtype('int64'),
+]
+_float_types: List[np.dtype] = [
+  np.dtype(bfloat16),
+  np.dtype('float16'),
+  np.dtype('float32'),
+  np.dtype('float64'),
+]
+_complex_types: List[np.dtype] = [
+  np.dtype('complex64'),
+  np.dtype('complex128'),
+]
+_jax_types = _bool_types + _int_types + _float_types + _complex_types
+_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}
 
 def _jax_type(dtype, weak_type):
   """Return the jax type for a dtype and weak type."""
@@ -270,23 +275,37 @@ def _dtype_and_weaktype(value):
   """Return a (dtype, weak_type) tuple for the given input."""
   return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
 
-def _type_promotion_lattice():
+def _type_promotion_lattice(jax_numpy_dtype_promotion):
   """
   Return the type promotion lattice in the form of a DAG.
   This DAG maps each type to its immediately higher type on the lattice.
   """
-  b1, u1, u2, u4, u8, i1, i2, i4, i8, bf, f2, f4, f8, c4, c8 = _jax_types
+  b1, = _bool_types
+  u1, u2, u4, u8, i1, i2, i4, i8 = _int_types
+  bf, f2, f4, f8 = _float_types
+  c4, c8 = _complex_types
   i_, f_, c_ = _weak_types
-  return {
-    b1: [i_],
-    u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
-    i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
-    f_: [bf, f2, c_], bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
-    c_: [c4], c4: [c8], c8: [],
-  }
+  if jax_numpy_dtype_promotion == 'standard':
+    return {
+      b1: [i_],
+      u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
+      i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
+      f_: [bf, f2, c_], bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
+      c_: [c4], c4: [c8], c8: [],
+    }
+  elif jax_numpy_dtype_promotion == 'strict':
+    return {
+      i_: [f_] + _int_types,
+      f_: [c_] + _float_types,
+      c_: _complex_types,
+      **{t: [] for t in _jax_types}
+    }
+  else:
+    raise ValueError(
+        f"Unexpected value of jax_numpy_dtype_promotion={jax_numpy_dtype_promotion!r}")
 
-def _make_lattice_upper_bounds():
-  lattice = _type_promotion_lattice()
+def _make_lattice_upper_bounds(jax_numpy_dtype_promotion):
+  lattice = _type_promotion_lattice(jax_numpy_dtype_promotion)
   upper_bounds = {node: {node} for node in lattice}
   for n in lattice:
     while True:
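To unpack the strict branch above: the weak Python scalar types keep a promotion path into their own family of concrete dtypes, while every concrete dtype maps to an empty successor list, so two distinct concrete dtypes share no promotion path at all. A standalone toy model of that shape (invented names, not JAX internals):

import numpy as np

int_like = [np.dtype('int32'), np.dtype('int64')]        # stand-ins for _int_types
float_like = [np.dtype('float32'), np.dtype('float64')]  # stand-ins for _float_types

toy_strict_lattice = {
    int: [float] + int_like,    # weak int -> any int dtype, or a weak float
    float: float_like,          # weak float -> any float dtype
    **{t: [] for t in int_like + float_like},  # concrete dtypes promote only to themselves
}
assert toy_strict_lattice[np.dtype('int32')] == []  # e.g. int32 + float32 has no path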
@@ -297,10 +316,17 @@ def _make_lattice_upper_bounds():
         break
       upper_bounds[n] |= new_upper_bounds
   return upper_bounds
-_lattice_upper_bounds = _make_lattice_upper_bounds()
+
+_lattice_upper_bounds = {
+  'standard': _make_lattice_upper_bounds('standard'),
+  'strict': _make_lattice_upper_bounds('strict'),
+}
 
 class TypePromotionError(ValueError):
   pass
 
 @functools.lru_cache(512) # don't use util.memoize because there is no X64 dependence.
-def _least_upper_bound(*nodes):
+def _least_upper_bound(jax_numpy_dtype_promotion, *nodes):
   """Compute the least upper bound of a set of nodes.
 
   Args:
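_make_lattice_upper_bounds (shown in context above) turns the promotion DAG into a per-mode table of every type reachable from a given type, so _least_upper_bound can work with plain set operations. A self-contained sketch of such a transitive closure on a toy lattice (illustrative only, not the JAX implementation):

def toy_upper_bounds(lattice):
  ub = {node: {node} for node in lattice}   # every node bounds itself
  changed = True
  while changed:                            # iterate to a global fixed point
    changed = False
    for n, successors in lattice.items():
      reachable = set().union(*(ub[s] for s in successors)) if successors else set()
      if not reachable <= ub[n]:
        ub[n] |= reachable
        changed = True
  return ub

toy = {'i32': ['f32'], 'f32': ['c64'], 'c64': []}
assert toy_upper_bounds(toy)['i32'] == {'i32', 'f32', 'c64'}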
@@ -327,13 +353,23 @@ def _least_upper_bound(*nodes):
   # ∀ c ∈ N: CUB(N) ⊆ UB(c)
   # So if N ∩ CUB(N) is nonempty, if follows that LUB(N) = N ∩ CUB(N).
   N = set(nodes)
-  UB = _lattice_upper_bounds
+  UB = _lattice_upper_bounds[jax_numpy_dtype_promotion]
   CUB = set.intersection(*(UB[n] for n in N))
   LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
   if len(LUB) == 1:
     return LUB.pop()
+  elif len(LUB) == 0:
+    # TODO(jakevdp): surface some error about jax_numpy_rank_promotion flag.
+    raise TypePromotionError(
+      f"Input dtypes {tuple(str(n) for n in nodes)} have no available implicit dtype "
+      "promotion path. Try explicitly casting inputs to the desired output type.")
   else:
-    raise ValueError(f"{nodes} do not have a unique least upper bound.")
+    # If we get here, it means the lattice is ill-formed.
+    raise TypePromotionError(
+      f"Internal Type Promotion error: {nodes} do not have a unique least upper bound "
+      f"on the specified lattice; options are {LUB}. If you see this error, please "
+      "report it to the JAX maintainers."
+    )
 
 def promote_types(a, b):
   """Returns the type to which a binary operation should cast its arguments.
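The comments in this hunk assert that the least upper bound is either one of the inputs themselves or the unique minimal element of the common upper bounds. A small worked example of that exact computation against a toy upper-bound table (illustrative, not JAX API):

UB = {
    'i32': {'i32', 'f32', 'c64'},   # toy upper-bound sets, as a closure like the one above would build
    'f32': {'f32', 'c64'},
    'c64': {'c64'},
}
N = {'i32', 'f32'}
CUB = set.intersection(*(UB[n] for n in N))                  # common upper bounds: {'f32', 'c64'}
LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}   # one input is itself an upper bound
assert LUB == {'f32'}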
@@ -351,7 +387,7 @@ def promote_types(a, b):
   # object identity, not object equality, due to the behavior of np.dtype.__eq__
   a = a if any(a is t for t in _weak_types) else np.dtype(a)
   b = b if any(b is t for t in _weak_types) else np.dtype(b)
-  return np.dtype(_least_upper_bound(a, b))
+  return np.dtype(_least_upper_bound(config.jax_numpy_dtype_promotion, a, b))
 
 def is_weakly_typed(x):
   try:
@@ -388,11 +424,14 @@ def _lattice_result_type(*args):
   # If all inputs are weakly typed, we compute the bound of the strongly-typed
   # counterparts and apply the weak type at the end. This avoids returning the
   # incorrect result with non-canonical weak types (e.g. weak int16).
-  if all(weak_types):
-    result_type = _least_upper_bound(*{_jax_type(dtype, False) for dtype in dtypes})
+  # TODO(jakevdp): explore removing this special case.
+  if all(weak_types) and config.jax_numpy_dtype_promotion != 'strict':
+    result_type = _least_upper_bound(config.jax_numpy_dtype_promotion,
+                                     *{_jax_type(dtype, False) for dtype in dtypes})
     return dtype(result_type), True
   else:
-    result_type = _least_upper_bound(*{_jax_type(d, w) for d, w in zip(dtypes, weak_types)})
+    result_type = _least_upper_bound(config.jax_numpy_dtype_promotion,
+                                     *{_jax_type(d, w) for d, w in zip(dtypes, weak_types)})
     return dtype(result_type), any(result_type is t for t in _weak_types)
 
 def result_type(*args, return_weak_type_flag=False):
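The != 'strict' guard above means the all-weak-types fast path is skipped in strict mode, while a weak Python scalar can still promote against a single strongly-typed array. A hedged behavioral sketch (illustrative; mixing two differing strong dtypes is the case that errors):

import jax
import jax.numpy as jnp

with jax.numpy_dtype_promotion('strict'):
  a = jnp.zeros(3, dtype='float32')
  print((a + 2).dtype)      # float32: a weak Python int adopts the array's dtype
  print((a + 2.5).dtype)    # float32: a weak Python float likewise
  # a + jnp.zeros(3, dtype='int32') would raise a TypePromotionError, because
  # both operands carry differing strongly-specified dtypes.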
@@ -29,7 +29,7 @@ from jax import numpy as jnp
 from jax._src import test_util as jtu
 from jax._src.lax import lax as lax_internal
 
-from jax.config import config
+from jax._src.config import config
 config.parse_flags_with_absl()
 
 FLAGS = config.FLAGS
@@ -136,7 +136,44 @@ class DtypesTest(jtu.JaxTestCase):
       self.assertTrue(isinstance(z, jnp.ndarray), msg=(x, y, z))
       self.assertEqual(z.dtype, dtypes.canonicalize_dtype(dtype), msg=(x, y, z))
 
-  def testPromoteDtypes(self):
+  @jax.numpy_dtype_promotion('strict')
+  def testPromoteDtypesStrict(self):
+    # Check that strong types have diagonal promotion table:
+    for t1 in all_dtypes:
+      for t2 in all_dtypes:
+        if t1 == t2:
+          self.assertEqual(t1, dtypes.promote_types(t1, t2))
+        else:
+          self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, t2)
+
+    # Promotion between weak types matches numpy promotion
+    for t1 in [int, float, complex]:
+      for t2 in [int, float, complex]:
+        py_result = type(t1(0) + t2(0))
+        lattice_dtype, lattice_weak_type = dtypes._lattice_result_type(t1, t2)
+        self.assertTrue(lattice_weak_type)
+        self.assertEqual(lattice_dtype, np.dtype(py_result))
+
+    # Check that weak promotion only works if strong value is not cast:
+    for t1 in bool_dtypes:
+      self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, int)
+      self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, float)
+      self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, complex)
+    for t1 in signed_dtypes + unsigned_dtypes:
+      self.assertEqual(dtypes.promote_types(t1, int), t1)
+      self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, float)
+      self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, complex)
+    for t1 in float_dtypes:
+      self.assertEqual(dtypes.promote_types(t1, int), t1)
+      self.assertEqual(dtypes.promote_types(t1, float), t1)
+      self.assertRaises(dtypes.TypePromotionError, dtypes.promote_types, t1, complex)
+    for t1 in complex_dtypes:
+      self.assertEqual(dtypes.promote_types(t1, int), t1)
+      self.assertEqual(dtypes.promote_types(t1, float), t1)
+      self.assertEqual(dtypes.promote_types(t1, complex), t1)
+
+  @jax.numpy_dtype_promotion('standard')
+  def testPromoteDtypesStandard(self):
     for t1 in all_dtypes:
       self.assertEqual(t1, dtypes.promote_types(t1, t1))
 
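The new tests use the object returned by define_enum_state as a decorator, scoping the flag to a single test method. A minimal sketch of the same pattern outside the JAX test suite (hypothetical test class; the imports mirror what the test file appears to use):

import unittest
import jax
import jax.numpy as jnp
from jax._src import dtypes

class MyPromotionTest(unittest.TestCase):

  @jax.numpy_dtype_promotion('strict')
  def test_mixed_strong_dtypes_raise(self):
    with self.assertRaises(dtypes.TypePromotionError):
      jnp.float32(1) + jnp.int32(1)

if __name__ == '__main__':
  unittest.main()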
@@ -163,7 +200,15 @@ class DtypesTest(jtu.JaxTestCase):
                    np_float_dtypes + complex_dtypes]:
       for t1, t2 in itertools.combinations(groups, 2):
         self.assertEqual(np.promote_types(t1, t2),
-                        dtypes.promote_types(t1, t2))
+                         dtypes.promote_types(t1, t2))
+
+    # Promotion between weak types matches numpy promotion
+    for t1 in [int, float, complex]:
+      for t2 in [int, float, complex]:
+        py_result = type(t1(0) + t2(0))
+        lattice_dtype, lattice_weak_type = dtypes._lattice_result_type(t1, t2)
+        self.assertTrue(lattice_weak_type)
+        self.assertEqual(lattice_dtype, np.dtype(py_result))
 
   @parameterized.parameters([jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64])
   def testScalarInstantiation(self, scalar_type):