mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
This commit is contained in:
parent
3b21615536
commit
6a6f13e1b0
@ -95,7 +95,6 @@ on this lattice, which generates the following binary promotion table:
|
||||
.. The table above was generated by the following Python code.
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from jax import dtypes
|
||||
|
||||
types = [np.bool_, np.uint8, np.uint16, np.uint32, np.uint64,
|
||||
np.int8, np.int16, np.int32, np.int64,
|
||||
|
@ -99,6 +99,7 @@ from .version import __version__
|
||||
|
||||
# These submodules are separate because they are in an import cycle with
|
||||
# jax and rely on the names imported above.
|
||||
from . import dtypes
|
||||
from . import errors
|
||||
from . import image
|
||||
from . import lax
|
||||
|
373
jax/_src/dtypes.py
Normal file
373
jax/_src/dtypes.py
Normal file
@ -0,0 +1,373 @@
|
||||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Array type functions.
|
||||
#
|
||||
# JAX dtypes differ from NumPy in both:
|
||||
# a) their type promotion rules, and
|
||||
# b) the set of supported types (e.g., bfloat16),
|
||||
# so we need our own implementation that deviates from NumPy in places.
|
||||
|
||||
|
||||
import functools
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import util
|
||||
from jax.config import flags, config
|
||||
from jax.lib import xla_client
|
||||
|
||||
from jax._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
# bfloat16 support
|
||||
bfloat16: type = xla_client.bfloat16
|
||||
_bfloat16_dtype: np.dtype = np.dtype(bfloat16)
|
||||
|
||||
# Default types.
|
||||
|
||||
bool_ = np.bool_
|
||||
int_: np.dtype = np.int64 # type: ignore
|
||||
float_: np.dtype = np.float64 # type: ignore
|
||||
complex_ = np.complex128
|
||||
|
||||
# TODO(phawkins): change the above defaults to:
|
||||
# int_ = np.int32
|
||||
# float_ = np.float32
|
||||
# complex_ = np.complex64
|
||||
|
||||
# Trivial vectorspace datatype needed for tangent values of int/bool primals
|
||||
float0 = np.dtype([('float0', np.void, 0)])
|
||||
|
||||
_dtype_to_32bit_dtype = {
|
||||
np.dtype('int64'): np.dtype('int32'),
|
||||
np.dtype('uint64'): np.dtype('uint32'),
|
||||
np.dtype('float64'): np.dtype('float32'),
|
||||
np.dtype('complex128'): np.dtype('complex64'),
|
||||
}
|
||||
|
||||
@util.memoize
def canonicalize_dtype(dtype):
  """Convert a dtype specifier to its canonical dtype per config.x64_enabled.

  Args:
    dtype: anything ``np.dtype`` accepts.

  Returns:
    ``np.dtype(dtype)`` unchanged when ``config.x64_enabled`` is true.
    Otherwise the matching entry of ``_dtype_to_32bit_dtype`` if there is one,
    else ``np.dtype(dtype)``.

  Raises:
    TypeError: if ``np.dtype(dtype)`` raises a TypeError.
  """
  try:
    canonical = np.dtype(dtype)
  except TypeError as err:
    raise TypeError(f'dtype {dtype!r} not understood') from err

  if not config.x64_enabled:
    # Only dtypes listed in the 64->32 bit table are narrowed; all others
    # pass through untouched.
    return _dtype_to_32bit_dtype.get(canonical, canonical)
  return canonical
|
||||
|
||||
|
||||
# Default dtypes corresponding to Python scalars.
|
||||
python_scalar_dtypes : dict = {
|
||||
bool: np.dtype(bool_),
|
||||
int: np.dtype(int_),
|
||||
float: np.dtype(float_),
|
||||
complex: np.dtype(complex_),
|
||||
}
|
||||
|
||||
def scalar_type_of(x):
  """Return the Python scalar type (bool, int, float, complex) for ``x``.

  The dtype of ``x`` is obtained with this module's ``dtype`` function and
  then classified: bfloat16 maps to ``float``; everything else is matched
  against NumPy's abstract scalar hierarchy.

  Raises:
    TypeError: if the dtype is none of boolean, integer, floating or complex.
  """
  typ = dtype(x)
  # bfloat16 is tested explicitly before the NumPy hierarchy checks.
  if typ == bfloat16:
    return float

  # Order matters: it mirrors the sequence of checks callers have always seen.
  hierarchy = (
      (np.bool_, bool),
      (np.integer, int),
      (np.floating, float),
      (np.complexfloating, complex),
  )
  for abstract_type, python_type in hierarchy:
    if np.issubdtype(typ, abstract_type):
      return python_type
  raise TypeError("Invalid scalar value {}".format(x))
|
||||
|
||||
|
||||
def _scalar_type_to_dtype(typ: type, value: Any = None):
  """Return the canonical numpy dtype for the given Python scalar type.

  The result is ``canonicalize_dtype(python_scalar_dtypes[typ])``, so it
  depends on ``config.x64_enabled``.

  Parameters
  ----------
  typ: a key of ``python_scalar_dtypes``.
  value: optional value to range-check; only consulted when ``typ is int``.

  Raises
  ------
  OverflowError: if `typ` is `int` and `value` lies outside the range of the
    canonical integer dtype (the dtype named in the error message; int32 in
    the examples below).
  KeyError: if `typ` is not a key of ``python_scalar_dtypes``.

  Examples
  --------
  The outputs below are the ones obtained when 64-bit types are canonicalized
  to 32-bit.

  >>> _scalar_type_to_dtype(int)
  dtype('int32')
  >>> _scalar_type_to_dtype(float)
  dtype('float32')
  >>> _scalar_type_to_dtype(complex)
  dtype('complex64')
  >>> _scalar_type_to_dtype(int, 0)
  dtype('int32')
  >>> _scalar_type_to_dtype(int, 1 << 63)  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
  OverflowError: Python int 9223372036854775808 too large to convert to int32
  """
  dtype = canonicalize_dtype(python_scalar_dtypes[typ])
  if typ is int and value is not None:
    # Look the integer limits up once instead of once per comparison.
    limits = np.iinfo(dtype)
    if value < limits.min or value > limits.max:
      # The same message is used for values below the minimum.
      raise OverflowError(f"Python int {value} too large to convert to {dtype}")
  return dtype
|
||||
|
||||
|
||||
def coerce_to_array(x, dtype=None):
  """Coerce a scalar or NumPy array to an np.array.

  When no ``dtype`` is given and ``type(x)`` is one of the keys of
  ``python_scalar_dtypes``, the dtype is chosen by ``_scalar_type_to_dtype``
  (JAX's rules) rather than by NumPy's defaults.

  Args:
    x: the value to convert.
    dtype: optional explicit dtype; when given it is passed straight to
      ``np.asarray``.

  Returns:
    The result of ``np.asarray`` on ``x`` with the selected dtype.
  """
  if dtype is not None or type(x) not in python_scalar_dtypes:
    return np.asarray(x, dtype)
  return np.asarray(x, _scalar_type_to_dtype(type(x), x))
|
||||
|
||||
iinfo = np.iinfo
|
||||
|
||||
class finfo(np.finfo):
  # Reuse NumPy's documentation; this subclass only adds bfloat16 handling.
  __doc__ = np.finfo.__doc__
  # Holds the lazily built bfloat16 record, keyed by the bfloat16 dtype.
  _finfo_cache: Dict[np.dtype, np.finfo] = {}

  @staticmethod
  def _bfloat16_finfo():
    """Build an ``np.finfo`` instance describing bfloat16 by hand."""
    to_bf16 = _bfloat16_dtype.type

    # Limits expressed as Python floats first; converted to bfloat16 below.
    tiny_f = float.fromhex("0x1p-126")
    resolution_f = 0.01
    eps_f = float.fromhex("0x1p-7")
    epsneg_f = float.fromhex("0x1p-8")
    max_f = float.fromhex("0x1.FEp127")

    # Bypass np.finfo's own constructor and fill every field manually.
    info = object.__new__(np.finfo)
    info.dtype = _bfloat16_dtype
    info.bits = 16
    info.eps = to_bf16(eps_f)
    info.epsneg = to_bf16(epsneg_f)
    info.machep = -7
    info.negep = -8
    info.max = to_bf16(max_f)
    info.min = to_bf16(-max_f)
    info.nexp = 8
    info.nmant = 7
    info.iexp = info.nexp
    info.precision = 2
    info.resolution = to_bf16(resolution_f)
    info.tiny = to_bf16(tiny_f)
    info.machar = None  # np.core.getlimits.MachArLike does not support bfloat16.

    # Pre-formatted string forms of the limits, stored as _str_<name>.
    formatted = (
        ("tiny", tiny_f),
        ("max", max_f),
        ("epsneg", epsneg_f),
        ("eps", eps_f),
        ("resolution", resolution_f),
    )
    for label, number in formatted:
      setattr(info, "_str_" + label, "%12.4e" % float(number))
    return info

  def __new__(cls, dtype):
    """Return the cached bfloat16 record, or defer to ``np.finfo``."""
    named_bfloat16 = isinstance(dtype, str) and dtype == 'bfloat16'
    if named_bfloat16 or dtype == _bfloat16_dtype:
      cache = cls._finfo_cache
      if _bfloat16_dtype not in cache:
        cache[_bfloat16_dtype] = cls._bfloat16_finfo()
      return cache[_bfloat16_dtype]
    return super().__new__(cls, dtype)
|
||||
|
||||
def _issubclass(a, b):
  """Return whether ``a`` is a subclass of ``b``, never raising TypeError.

  Behaves like the builtin ``issubclass``, except that a ``TypeError`` from
  the builtin (for instance because ``a`` is not a class) yields ``False``.
  """
  try:
    outcome = issubclass(a, b)
  except TypeError:
    outcome = False
  return outcome
|
||||
|
||||
def issubdtype(a, b):
  """Variant of ``np.issubdtype`` with special handling for bfloat16.

  Args:
    a: the dtype (or dtype-like) being tested.
    b: the dtype, scalar type or abstract type to test against.

  Returns:
    For ``a == bfloat16``: whether ``b`` is the bfloat16 dtype (when ``b`` is
    an ``np.dtype``) or one of bfloat16, ``np.floating``, ``np.inexact``,
    ``np.number`` (otherwise). For any other ``a``: ``np.issubdtype(a, b)``
    after ``b`` has been normalized to a NumPy scalar type.
  """
  if a == bfloat16:
    if isinstance(b, np.dtype):
      return b == _bfloat16_dtype
    bfloat16_parents = [bfloat16, np.floating, np.inexact, np.number]
    return b in bfloat16_parents

  b_is_numpy_type = _issubclass(b, np.generic)
  if not b_is_numpy_type:
    # NumPy's issubdtype has a backward-compatibility behavior for its second
    # argument that interacts badly with JAX's custom scalar types, so the
    # second argument is explicitly converted to a NumPy type object first.
    b = np.dtype(b).type
  return np.issubdtype(a, b)
|
||||
|
||||
can_cast = np.can_cast
|
||||
issubsctype = np.issubsctype
|
||||
|
||||
def dtype_real(typ):
  """Return the type holding the real part of the input type.

  Args:
    typ: a dtype or dtype-like accepted by ``np.issubdtype``.

  Returns:
    The float32 dtype for complex64 and the float64 dtype for complex128.
    Any input that is not a complex floating type is returned unchanged.

  Raises:
    TypeError: if ``typ`` is a complex floating type other than complex64 or
      complex128.
  """
  if not np.issubdtype(typ, np.complexfloating):
    return typ

  complex_to_real = (
      (np.dtype('complex64'), np.dtype('float32')),
      (np.dtype('complex128'), np.dtype('float64')),
  )
  for complex_dtype, real_dtype in complex_to_real:
    if typ == complex_dtype:
      return real_dtype
  raise TypeError("Unknown complex floating type {}".format(typ))
|
||||
|
||||
# Enumeration of all valid JAX types in order.
|
||||
_weak_types = [int, float, complex]
|
||||
_jax_types = [
|
||||
np.dtype('bool'),
|
||||
np.dtype('uint8'),
|
||||
np.dtype('uint16'),
|
||||
np.dtype('uint32'),
|
||||
np.dtype('uint64'),
|
||||
np.dtype('int8'),
|
||||
np.dtype('int16'),
|
||||
np.dtype('int32'),
|
||||
np.dtype('int64'),
|
||||
np.dtype(bfloat16),
|
||||
np.dtype('float16'),
|
||||
np.dtype('float32'),
|
||||
np.dtype('float64'),
|
||||
np.dtype('complex64'),
|
||||
np.dtype('complex128'),
|
||||
] + _weak_types # type: ignore[operator]
|
||||
|
||||
def _jax_type(dtype, weak_type):
  """Return the jax type for a dtype and weak type.

  A weakly typed, non-boolean dtype maps to the Python scalar type obtained
  from ``dtype.type(0).item()``; in every other case ``dtype`` itself is
  returned.
  """
  if weak_type and dtype != bool:
    # Build a zero of this dtype and see which Python type NumPy hands back.
    python_zero = dtype.type(0).item()
    return type(python_zero)
  return dtype
|
||||
|
||||
def _dtype_and_weaktype(value):
  """Return a (dtype, weak_type) tuple for the given input.

  The input counts as weak when it *is* one of the ``_weak_types`` objects
  (identity check) or when ``is_weakly_typed`` says so.
  """
  value_dtype = dtype(value)
  is_weak_type_object = False
  for typ in _weak_types:
    if value is typ:
      is_weak_type_object = True
      break
  return value_dtype, is_weak_type_object or is_weakly_typed(value)
|
||||
|
||||
def _type_promotion_lattice():
  """Return the type promotion lattice in the form of a DAG.

  The DAG is a dict mapping each entry of ``_jax_types`` to the list of types
  immediately above it on the lattice. Local names follow the order of
  ``_jax_types``; ``i_``, ``f_`` and ``c_`` are its last three entries.
  """
  (b1, u1, u2, u4, u8, i1, i2, i4, i8,
   bf, f2, f4, f8, c4, c8, i_, f_, c_) = _jax_types
  edges = [
      (b1, [i_]),
      (u1, [i2, u2]),
      (u2, [i4, u4]),
      (u4, [i8, u8]),
      (u8, [f_]),
      (i_, [u1, i1]),
      (i1, [i2]),
      (i2, [i4]),
      (i4, [i8]),
      (i8, [f_]),
      (f_, [bf, f2, c_]),
      (bf, [f4]),
      (f2, [f4]),
      (f4, [f8, c4]),
      (f8, [c8]),
      (c_, [c4]),
      (c4, [c8]),
      (c8, []),
  ]
  return dict(edges)
|
||||
|
||||
def _make_lattice_upper_bounds():
  """Compute, for every lattice node, the set of nodes at or above it.

  Returns:
    A dict mapping each key of the promotion lattice to a set holding that
    node plus every node reachable from it by following lattice edges.

  Raises:
    ValueError: if some node is reachable from itself, i.e. the lattice
      contains a cycle.
  """
  lattice = _type_promotion_lattice()
  upper_bounds = {node: {node} for node in lattice}
  for node, reachable in upper_bounds.items():
    # Grow `reachable` in place until following edges adds nothing new.
    while True:
      successors = set()
      for member in reachable:
        successors.update(lattice[member])
      if node in successors:
        raise ValueError(f"cycle detected in type promotion lattice for node {node}")
      if successors <= reachable:
        break
      reachable |= successors
  return upper_bounds
|
||||
_lattice_upper_bounds = _make_lattice_upper_bounds()
|
||||
|
||||
@functools.lru_cache(512)  # don't use util.memoize because there is no X64 dependence.
def _least_upper_bound(*nodes):
  """Compute the least upper bound of a set of nodes.

  Args:
    nodes: sequence of entries from _jax_types
  Returns:
    the _jax_type representing the least upper bound of the input nodes
      on the promotion lattice.
  Raises:
    ValueError: if the nodes do not have exactly one least upper bound.
  """
  # Notation, for a partially ordered set S given by the promotion lattice:
  #   UB(n)  ≡ {m ∈ S | n ≤ m}                 upper bounds of a node n
  #   CUB(N) ≡ {a ∈ S | ∀ b ∈ N: a ∈ UB(b)}    common upper bounds of N ⊆ S
  #   LUB(N) ≡ {c ∈ CUB(N) | ∀ d ∈ CUB(N), c ≤ d}
  # Because c ≤ d holds exactly when d ∈ UB(c), this can be rewritten as
  #   LUB(N) = {c ∈ CUB(N) | CUB(N) ⊆ UB(c)}
  # which is the form evaluated below. A least upper bound, when it exists,
  # is unique, so a result of any other size is reported as an error.
  #
  # Shortcut: every c ∈ N satisfies CUB(N) ⊆ UB(c) by the definition of
  # CUB(N). So if N ∩ CUB(N) is nonempty, it follows that
  # LUB(N) = N ∩ CUB(N), and the general search can be skipped.
  node_set = set(nodes)
  bounds_of = _lattice_upper_bounds
  common = set.intersection(*(bounds_of[n] for n in node_set))

  candidates = common & node_set
  if not candidates:
    candidates = {c for c in common if common.issubset(bounds_of[c])}

  if len(candidates) != 1:
    raise ValueError(f"{nodes} do not have a unique least upper bound.")
  return candidates.pop()
|
||||
|
||||
def promote_types(a, b):
  """Returns the type to which a binary operation should cast its arguments.

  For details of JAX's type promotion semantics, see :ref:`type-promotion`.

  Args:
    a: a :class:`numpy.dtype` or a dtype specifier.
    b: a :class:`numpy.dtype` or a dtype specifier.

  Returns:
    A :class:`numpy.dtype` object.
  """
  def as_lattice_node(spec):
    # The weak type objects themselves are lattice nodes and are kept as-is
    # (identity check); anything else is normalized through np.dtype.
    for weak in _weak_types:
      if spec is weak:
        return spec
    return np.dtype(spec)

  bound = _least_upper_bound(as_lattice_node(a), as_lattice_node(b))
  return np.dtype(bound)
|
||||
|
||||
def is_weakly_typed(x):
  """Report whether ``x`` is weakly typed.

  Returns ``x.aval.weak_type`` when that attribute chain exists; otherwise
  falls back to whether ``type(x)`` is one of ``_weak_types``.
  """
  try:
    weak = x.aval.weak_type
  except AttributeError:
    weak = type(x) in _weak_types
  return weak
|
||||
|
||||
def is_python_scalar(x):
  """Report whether ``x`` should be treated as a Python scalar.

  When ``x.aval.weak_type`` is available, the result is
  ``x.aval.weak_type and np.ndim(x) == 0``. If that lookup raises
  ``AttributeError``, the result is whether ``type(x)`` is a key of
  ``python_scalar_dtypes``.
  """
  try:
    verdict = x.aval.weak_type and np.ndim(x) == 0
  except AttributeError:
    verdict = type(x) in python_scalar_dtypes
  return verdict
|
||||
|
||||
def dtype(x):
  """Return the dtype associated with ``x``.

  Values whose exact type is a key of ``python_scalar_dtypes`` get the dtype
  stored there; anything else is delegated to ``np.result_type``.
  """
  scalar_dtype = python_scalar_dtypes.get(type(x))
  if scalar_dtype is not None:
    return scalar_dtype
  return np.result_type(x)
|
||||
|
||||
def _lattice_result_type(*args):
  """Promote the inputs on the type lattice.

  Args:
    *args: one or more values, dtypes or types understood by
      ``_dtype_and_weaktype``.

  Returns:
    A ``(dtype, weak_type)`` pair for the promoted result.
  """
  dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
  if len(dtypes) == 1:
    return dtypes[0], weak_types[0]

  if not all(weak_types):
    # Mixed case: weak inputs enter the lattice as Python scalar types.
    nodes = {_jax_type(d, w) for d, w in zip(dtypes, weak_types)}
    bound = _least_upper_bound(*nodes)
    return dtype(bound), any(bound is t for t in _weak_types)

  # If all inputs are weakly typed, we compute the bound of the strongly-typed
  # counterparts and apply the weak type at the end. This avoids returning the
  # incorrect result with non-canonical weak types (e.g. weak int16).
  nodes = {_jax_type(d, False) for d in dtypes}
  bound = _least_upper_bound(*nodes)
  return dtype(bound), True
|
||||
|
||||
def result_type(*args):
  """Convenience function to apply JAX argument dtype promotion.

  Args:
    *args: one or more arrays, scalars or dtypes.

  Returns:
    The promoted dtype from ``_lattice_result_type``, passed through
    ``canonicalize_dtype``.

  Raises:
    ValueError: if called with no arguments.
  """
  if not args:
    raise ValueError("at least one array or dtype is required")
  promoted, _ = _lattice_result_type(*args)
  return canonicalize_dtype(promoted)
|
@ -30,7 +30,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax import api
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src.lax import lax
|
||||
|
@ -21,7 +21,8 @@ from jax.api import jit, linear_transpose, ShapeDtypeStruct
|
||||
from jax.core import Primitive
|
||||
from jax.interpreters import xla
|
||||
from jax._src.util import prod
|
||||
from jax import dtypes, lax
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax.lib import xla_client
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
|
@ -32,7 +32,7 @@ from jax import ad_util
|
||||
from jax import api
|
||||
from jax import api_util
|
||||
from jax import linear_util as lu
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax import tree_util
|
||||
from jax.config import flags, config
|
||||
from jax.core import (Primitive, _canonicalize_dimension, UnshapedArray,
|
||||
|
@ -21,7 +21,7 @@ from jax import ad_util
|
||||
from jax import api
|
||||
from jax import lax
|
||||
from jax import ops
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
|
@ -23,7 +23,7 @@ from typing import Union
|
||||
import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax import tree_util
|
||||
from . import lax
|
||||
from jax.core import ShapedArray, AxisName, raise_to_shaped
|
||||
|
@ -20,7 +20,7 @@ import numpy as np
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
from jax import custom_jvp
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax import core
|
||||
from jax.core import AxisName
|
||||
|
@ -41,7 +41,7 @@ from jax import jit, custom_jvp
|
||||
from .vectorize import vectorize
|
||||
from .util import _wraps
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax import errors
|
||||
from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
|
||||
from jax.config import config
|
||||
|
@ -24,7 +24,7 @@ from jax import jit, custom_jvp
|
||||
from jax import lax
|
||||
from jax import ops
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from .util import _wraps
|
||||
from . import lax_numpy as jnp
|
||||
from jax._src.util import canonicalize_axis
|
||||
|
@ -22,7 +22,7 @@ import numpy as np
|
||||
from jax import lax
|
||||
from jax import core
|
||||
from jax import numpy as jnp
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax.core import NamedShape
|
||||
from jax.api import jit, vmap
|
||||
from jax._src.numpy.lax_numpy import _constant_like, _convert_and_clip_integer, asarray
|
||||
|
@ -18,7 +18,7 @@ import numpy as np
|
||||
|
||||
from . import ad_util
|
||||
from . import core
|
||||
from . import dtypes
|
||||
from ._src import dtypes
|
||||
|
||||
from ._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
@ -40,7 +40,7 @@ from . import core
|
||||
from . import lib
|
||||
from . import linear_util as lu
|
||||
from . import ad_util
|
||||
from . import dtypes
|
||||
from ._src import dtypes
|
||||
from .core import eval_jaxpr
|
||||
from .api_util import (flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
|
||||
flatten_fun_nokwargs2, argnums_partial,
|
||||
|
@ -18,7 +18,7 @@ from typing import Any, Dict, Iterable, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
from . import core
|
||||
from . import dtypes
|
||||
from ._src import dtypes
|
||||
from .tree_util import (tree_flatten, tree_unflatten, tree_multimap,
|
||||
tree_structure, treedef_children, treedef_is_leaf)
|
||||
from ._src.tree_util import _replace_nones
|
||||
|
@ -28,7 +28,7 @@ from typing import (Any, Callable, ClassVar, Dict, Generator,
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import dtypes
|
||||
from ._src import dtypes
|
||||
from .config import FLAGS, config
|
||||
from .errors import (ConcretizationTypeError, TracerArrayConversionError,
|
||||
TracerIntegerConversionError)
|
||||
|
@ -19,7 +19,7 @@ import operator as op
|
||||
from typing import Callable, Sequence, Tuple, Any
|
||||
|
||||
from . import core
|
||||
from . import dtypes
|
||||
from ._src import dtypes
|
||||
from . import linear_util as lu
|
||||
from .tree_util import (tree_flatten, tree_unflatten, tree_map,
|
||||
tree_multimap, treedef_is_leaf, treedef_tuple,
|
||||
|
371
jax/dtypes.py
371
jax/dtypes.py
@ -12,362 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Array type functions.
|
||||
#
|
||||
# JAX dtypes differ from NumPy in both:
|
||||
# a) their type promotion rules, and
|
||||
# b) the set of supported types (e.g., bfloat16),
|
||||
# so we need our own implementation that deviates from NumPy in places.
|
||||
|
||||
|
||||
import functools
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._src import util
|
||||
from .config import flags, config
|
||||
from .lib import xla_client
|
||||
|
||||
from ._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
# bfloat16 support
|
||||
bfloat16: type = xla_client.bfloat16
|
||||
_bfloat16_dtype: np.dtype = np.dtype(bfloat16)
|
||||
|
||||
# Default types.
|
||||
|
||||
bool_ = np.bool_
|
||||
int_: np.dtype = np.int64 # type: ignore
|
||||
float_: np.dtype = np.float64 # type: ignore
|
||||
complex_ = np.complex128
|
||||
|
||||
# TODO(phawkins): change the above defaults to:
|
||||
# int_ = np.int32
|
||||
# float_ = np.float32
|
||||
# complex_ = np.complex64
|
||||
|
||||
# Trivial vectorspace datatype needed for tangent values of int/bool primals
|
||||
float0 = np.dtype([('float0', np.void, 0)])
|
||||
|
||||
_dtype_to_32bit_dtype = {
|
||||
np.dtype('int64'): np.dtype('int32'),
|
||||
np.dtype('uint64'): np.dtype('uint32'),
|
||||
np.dtype('float64'): np.dtype('float32'),
|
||||
np.dtype('complex128'): np.dtype('complex64'),
|
||||
}
|
||||
|
||||
@util.memoize
|
||||
def canonicalize_dtype(dtype):
|
||||
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
|
||||
try:
|
||||
dtype = np.dtype(dtype)
|
||||
except TypeError as e:
|
||||
raise TypeError(f'dtype {dtype!r} not understood') from e
|
||||
|
||||
if config.x64_enabled:
|
||||
return dtype
|
||||
else:
|
||||
return _dtype_to_32bit_dtype.get(dtype, dtype)
|
||||
|
||||
|
||||
# Default dtypes corresponding to Python scalars.
|
||||
python_scalar_dtypes : dict = {
|
||||
bool: np.dtype(bool_),
|
||||
int: np.dtype(int_),
|
||||
float: np.dtype(float_),
|
||||
complex: np.dtype(complex_),
|
||||
}
|
||||
|
||||
def scalar_type_of(x):
|
||||
typ = dtype(x)
|
||||
if typ == bfloat16:
|
||||
return float
|
||||
elif np.issubdtype(typ, np.bool_):
|
||||
return bool
|
||||
elif np.issubdtype(typ, np.integer):
|
||||
return int
|
||||
elif np.issubdtype(typ, np.floating):
|
||||
return float
|
||||
elif np.issubdtype(typ, np.complexfloating):
|
||||
return complex
|
||||
else:
|
||||
raise TypeError("Invalid scalar value {}".format(x))
|
||||
|
||||
|
||||
def _scalar_type_to_dtype(typ: type, value: Any = None):
|
||||
"""Return the numpy dtype for the given scalar type.
|
||||
|
||||
Raises
|
||||
------
|
||||
OverflowError: if `typ` is `int` and the value is too large for int64.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> _scalar_type_to_dtype(int)
|
||||
dtype('int32')
|
||||
>>> _scalar_type_to_dtype(float)
|
||||
dtype('float32')
|
||||
>>> _scalar_type_to_dtype(complex)
|
||||
dtype('complex64')
|
||||
>>> _scalar_type_to_dtype(int)
|
||||
dtype('int32')
|
||||
>>> _scalar_type_to_dtype(int, 0)
|
||||
dtype('int32')
|
||||
>>> _scalar_type_to_dtype(int, 1 << 63) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
OverflowError: Python int 9223372036854775808 too large to convert to int32
|
||||
"""
|
||||
dtype = canonicalize_dtype(python_scalar_dtypes[typ])
|
||||
if typ is int and value is not None:
|
||||
if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max:
|
||||
raise OverflowError(f"Python int {value} too large to convert to {dtype}")
|
||||
return dtype
|
||||
|
||||
|
||||
def coerce_to_array(x, dtype=None):
|
||||
"""Coerces a scalar or NumPy array to an np.array.
|
||||
|
||||
Handles Python scalar type promotion according to JAX's rules, not NumPy's
|
||||
rules.
|
||||
"""
|
||||
if dtype is None and type(x) in python_scalar_dtypes:
|
||||
dtype = _scalar_type_to_dtype(type(x), x)
|
||||
return np.asarray(x, dtype)
|
||||
|
||||
iinfo = np.iinfo
|
||||
|
||||
class finfo(np.finfo):
|
||||
__doc__ = np.finfo.__doc__
|
||||
_finfo_cache: Dict[np.dtype, np.finfo] = {}
|
||||
@staticmethod
|
||||
def _bfloat16_finfo():
|
||||
def float_to_str(f):
|
||||
return "%12.4e" % float(f)
|
||||
|
||||
bfloat16 = _bfloat16_dtype.type
|
||||
tiny = float.fromhex("0x1p-126")
|
||||
resolution = 0.01
|
||||
eps = float.fromhex("0x1p-7")
|
||||
epsneg = float.fromhex("0x1p-8")
|
||||
max = float.fromhex("0x1.FEp127")
|
||||
|
||||
obj = object.__new__(np.finfo)
|
||||
obj.dtype = _bfloat16_dtype
|
||||
obj.bits = 16
|
||||
obj.eps = bfloat16(eps)
|
||||
obj.epsneg = bfloat16(epsneg)
|
||||
obj.machep = -7
|
||||
obj.negep = -8
|
||||
obj.max = bfloat16(max)
|
||||
obj.min = bfloat16(-max)
|
||||
obj.nexp = 8
|
||||
obj.nmant = 7
|
||||
obj.iexp = obj.nexp
|
||||
obj.precision = 2
|
||||
obj.resolution = bfloat16(resolution)
|
||||
obj.tiny = bfloat16(tiny)
|
||||
obj.machar = None # np.core.getlimits.MachArLike does not support bfloat16.
|
||||
|
||||
obj._str_tiny = float_to_str(tiny)
|
||||
obj._str_max = float_to_str(max)
|
||||
obj._str_epsneg = float_to_str(epsneg)
|
||||
obj._str_eps = float_to_str(eps)
|
||||
obj._str_resolution = float_to_str(resolution)
|
||||
return obj
|
||||
|
||||
def __new__(cls, dtype):
|
||||
if isinstance(dtype, str) and dtype == 'bfloat16' or dtype == _bfloat16_dtype:
|
||||
if _bfloat16_dtype not in cls._finfo_cache:
|
||||
cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
|
||||
return cls._finfo_cache[_bfloat16_dtype]
|
||||
return super().__new__(cls, dtype)
|
||||
|
||||
def _issubclass(a, b):
|
||||
"""Determines if ``a`` is a subclass of ``b``.
|
||||
|
||||
Similar to issubclass, but returns False instead of an exception if `a` is not
|
||||
a class.
|
||||
"""
|
||||
try:
|
||||
return issubclass(a, b)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
def issubdtype(a, b):
|
||||
if a == bfloat16:
|
||||
if isinstance(b, np.dtype):
|
||||
return b == _bfloat16_dtype
|
||||
else:
|
||||
return b in [bfloat16, np.floating, np.inexact, np.number]
|
||||
if not _issubclass(b, np.generic):
|
||||
# Workaround for JAX scalar types. NumPy's issubdtype has a backward
|
||||
# compatibility behavior for the second argument of issubdtype that
|
||||
# interacts badly with JAX's custom scalar types. As a workaround,
|
||||
# explicitly cast the second argument to a NumPy type object.
|
||||
b = np.dtype(b).type
|
||||
return np.issubdtype(a, b)
|
||||
|
||||
can_cast = np.can_cast
|
||||
issubsctype = np.issubsctype
|
||||
|
||||
# Return the type holding the real part of the input type
|
||||
def dtype_real(typ):
|
||||
if np.issubdtype(typ, np.complexfloating):
|
||||
if typ == np.dtype('complex64'):
|
||||
return np.dtype('float32')
|
||||
elif typ == np.dtype('complex128'):
|
||||
return np.dtype('float64')
|
||||
else:
|
||||
raise TypeError("Unknown complex floating type {}".format(typ))
|
||||
else:
|
||||
return typ
|
||||
|
||||
# Enumeration of all valid JAX types in order.
|
||||
_weak_types = [int, float, complex]
|
||||
_jax_types = [
|
||||
np.dtype('bool'),
|
||||
np.dtype('uint8'),
|
||||
np.dtype('uint16'),
|
||||
np.dtype('uint32'),
|
||||
np.dtype('uint64'),
|
||||
np.dtype('int8'),
|
||||
np.dtype('int16'),
|
||||
np.dtype('int32'),
|
||||
np.dtype('int64'),
|
||||
np.dtype(bfloat16),
|
||||
np.dtype('float16'),
|
||||
np.dtype('float32'),
|
||||
np.dtype('float64'),
|
||||
np.dtype('complex64'),
|
||||
np.dtype('complex128'),
|
||||
] + _weak_types # type: ignore[operator]
|
||||
|
||||
def _jax_type(dtype, weak_type):
|
||||
"""Return the jax type for a dtype and weak type."""
|
||||
return type(dtype.type(0).item()) if (weak_type and dtype != bool) else dtype
|
||||
|
||||
def _dtype_and_weaktype(value):
|
||||
"""Return a (dtype, weak_type) tuple for the given input."""
|
||||
return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
|
||||
|
||||
def _type_promotion_lattice():
|
||||
"""
|
||||
Return the type promotion lattice in the form of a DAG.
|
||||
This DAG maps each type to its immediately higher type on the lattice.
|
||||
"""
|
||||
b1, u1, u2, u4, u8, i1, i2, i4, i8, bf, f2, f4, f8, c4, c8, i_, f_, c_ = _jax_types
|
||||
return {
|
||||
b1: [i_],
|
||||
u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
|
||||
i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
|
||||
f_: [bf, f2, c_], bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
|
||||
c_: [c4], c4: [c8], c8: [],
|
||||
}
|
||||
|
||||
def _make_lattice_upper_bounds():
|
||||
lattice = _type_promotion_lattice()
|
||||
upper_bounds = {node: {node} for node in lattice}
|
||||
for n in lattice:
|
||||
while True:
|
||||
new_upper_bounds = set().union(*(lattice[b] for b in upper_bounds[n]))
|
||||
if n in new_upper_bounds:
|
||||
raise ValueError(f"cycle detected in type promotion lattice for node {n}")
|
||||
if new_upper_bounds.issubset(upper_bounds[n]):
|
||||
break
|
||||
upper_bounds[n] |= new_upper_bounds
|
||||
return upper_bounds
|
||||
_lattice_upper_bounds = _make_lattice_upper_bounds()
|
||||
|
||||
@functools.lru_cache(512) # don't use util.memoize because there is no X64 dependence.
|
||||
def _least_upper_bound(*nodes):
|
||||
"""Compute the least upper bound of a set of nodes.
|
||||
|
||||
Args:
|
||||
nodes: sequence of entries from _jax_types
|
||||
Returns:
|
||||
the _jax_type representing the least upper bound of the input nodes
|
||||
on the promotion lattice.
|
||||
"""
|
||||
# This function computes the least upper bound of a set of nodes N within a partially
|
||||
# ordered set defined by the lattice generated above.
|
||||
# Given a partially ordered set S, let the set of upper bounds of n ∈ S be
|
||||
# UB(n) ≡ {m ∈ S | n ≤ m}
|
||||
# Further, for a set of nodes N ⊆ S, let the set of common upper bounds be given by
|
||||
# CUB(N) ≡ {a ∈ S | ∀ b ∈ N: a ∈ UB(b)}
|
||||
# Then the least upper bound of N is defined as
|
||||
# LUB(N) ≡ {c ∈ CUB(N) | ∀ d ∈ CUB(N), c ≤ d}
|
||||
# The definition of an upper bound implies that c ≤ d if and only if d ∈ UB(c),
|
||||
# so the LUB can be expressed:
|
||||
# LUB(N) = {c ∈ CUB(N) | ∀ d ∈ CUB(N): d ∈ UB(c)}
|
||||
# or, equivalently:
|
||||
# LUB(N) = {c ∈ CUB(N) | CUB(N) ⊆ UB(c)}
|
||||
# By definition, LUB(N) has a cardinality of 1 for a partially ordered set.
|
||||
# Note a potential algorithmic shortcut: from the definition of CUB(N), we have
|
||||
# ∀ c ∈ N: CUB(N) ⊆ UB(c)
|
||||
# So if N ∩ CUB(N) is nonempty, if follows that LUB(N) = N ∩ CUB(N).
|
||||
N = set(nodes)
|
||||
UB = _lattice_upper_bounds
|
||||
CUB = set.intersection(*(UB[n] for n in N))
|
||||
LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
|
||||
if len(LUB) == 1:
|
||||
return LUB.pop()
|
||||
else:
|
||||
raise ValueError(f"{nodes} do not have a unique least upper bound.")
|
||||
|
||||
def promote_types(a, b):
|
||||
"""Returns the type to which a binary operation should cast its arguments.
|
||||
|
||||
For details of JAX's type promotion semantics, see :ref:`type-promotion`.
|
||||
|
||||
Args:
|
||||
a: a :class:`numpy.dtype` or a dtype specifier.
|
||||
b: a :class:`numpy.dtype` or a dtype specifier.
|
||||
|
||||
Returns:
|
||||
A :class:`numpy.dtype` object.
|
||||
"""
|
||||
a = a if any(a is t for t in _weak_types) else np.dtype(a)
|
||||
b = b if any(b is t for t in _weak_types) else np.dtype(b)
|
||||
return np.dtype(_least_upper_bound(a, b))
|
||||
|
||||
def is_weakly_typed(x):
|
||||
try:
|
||||
return x.aval.weak_type
|
||||
except AttributeError:
|
||||
return type(x) in _weak_types
|
||||
|
||||
def is_python_scalar(x):
|
||||
try:
|
||||
return x.aval.weak_type and np.ndim(x) == 0
|
||||
except AttributeError:
|
||||
return type(x) in python_scalar_dtypes
|
||||
|
||||
def dtype(x):
|
||||
if type(x) in python_scalar_dtypes:
|
||||
return python_scalar_dtypes[type(x)]
|
||||
return np.result_type(x)
|
||||
|
||||
def _lattice_result_type(*args):
  """Computes the promoted dtype of `args` and whether it is weakly typed.

  Returns:
    A `(dtype, weak_type)` pair.
  """
  arg_dtypes, arg_weak = zip(*[_dtype_and_weaktype(arg) for arg in args])

  # A single argument needs no promotion.
  if len(arg_dtypes) == 1:
    return arg_dtypes[0], arg_weak[0]

  # If all inputs are weakly typed, we compute the bound of the strongly-typed
  # counterparts and apply the weak type at the end. This avoids returning the
  # incorrect result with non-canonical weak types (e.g. weak int16).
  if all(arg_weak):
    strong_nodes = {_jax_type(dt, False) for dt in arg_dtypes}
    return dtype(_least_upper_bound(*strong_nodes)), True

  # Mixed or all-strong inputs: promote on the lattice directly; the result is
  # weak only if the bound itself is one of the weak lattice nodes.
  nodes = {_jax_type(dt, weak) for dt, weak in zip(arg_dtypes, arg_weak)}
  bound = _least_upper_bound(*nodes)
  return dtype(bound), any(bound is weak_node for weak_node in _weak_types)
|
||||
|
||||
def result_type(*args):
  """Convenience function to apply JAX argument dtype promotion.

  Raises:
    ValueError: if called with no arguments.
  """
  if not args:
    raise ValueError("at least one array or dtype is required")
  # Only the dtype half of the (dtype, weak_type) pair is needed here.
  promoted, _ = _lattice_result_type(*args)
  return canonicalize_dtype(promoted)
|
||||
# flake8: noqa: F401
|
||||
from jax._src.dtypes import (
|
||||
_jax_types, # TODO(phawkins): fix users and remove?
|
||||
bfloat16,
|
||||
canonicalize_dtype,
|
||||
finfo, # TODO(phawkins): switch callers to jnp.finfo?
|
||||
float0,
|
||||
iinfo, # TODO(phawkins): switch callers to jnp.iinfo?
|
||||
issubdtype, # TODO(phawkins): switch callers to jnp.issubdtype?
|
||||
result_type,
|
||||
scalar_type_of,
|
||||
)
|
||||
|
@ -19,7 +19,7 @@ from typing import (Tuple, List, Sequence, Set, Dict, Any, Callable, Union,
|
||||
Optional)
|
||||
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax.core import Var, Literal, Atom, Tracer
|
||||
from jax._src.util import (safe_zip, safe_map, curry, unzip2, split_list,
|
||||
tuple_delete)
|
||||
|
@ -352,7 +352,7 @@ from jax import api
|
||||
from jax import core
|
||||
from jax.config import config
|
||||
from jax import custom_derivatives
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax.lib import pytree
|
||||
from jax.lib import xla_client
|
||||
|
@ -17,7 +17,7 @@ import itertools
|
||||
import numpy as np
|
||||
from typing import Any, Callable, Optional, Sequence, Union
|
||||
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
|
||||
|
@ -20,7 +20,7 @@ from .tree_util import tree_flatten, tree_unflatten
|
||||
from ._src.util import safe_zip, unzip2
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
|
||||
zip = safe_zip
|
||||
|
@ -20,7 +20,7 @@ from typing import Any, Callable, Dict
|
||||
from . import partial_eval as pe
|
||||
from ..config import config
|
||||
from .. import core
|
||||
from ..dtypes import dtype, float0
|
||||
from .._src.dtypes import dtype, float0
|
||||
from ..core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
|
||||
raise_to_shaped)
|
||||
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
|
||||
|
@ -22,7 +22,8 @@ from typing import Callable, Dict, Optional, Sequence, Union, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import core, dtypes
|
||||
from .. import core
|
||||
from .._src import dtypes
|
||||
from ..tree_util import tree_unflatten
|
||||
from ..core import ShapedArray, Trace, Tracer
|
||||
from .._src.util import safe_map, safe_zip, unzip2, prod, wrap_name
|
||||
|
@ -23,7 +23,7 @@ from weakref import ref
|
||||
import numpy as np
|
||||
|
||||
from .. import core
|
||||
from .. import dtypes
|
||||
from .._src import dtypes
|
||||
from .. import linear_util as lu
|
||||
from ..ad_util import Zero
|
||||
from .._src.util import (unzip2, safe_zip, safe_map, toposort, partial,
|
||||
|
@ -27,7 +27,7 @@ import numpy as np
|
||||
from ..config import config
|
||||
from .. import core
|
||||
from .. import ad_util
|
||||
from .. import dtypes
|
||||
from jax._src import dtypes
|
||||
from .. import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
from ..abstract_arrays import (make_shaped_array, array_types)
|
||||
|
@ -21,7 +21,7 @@ import numpy as np
|
||||
import opt_einsum
|
||||
import scipy.special
|
||||
|
||||
from . import dtypes
|
||||
from jax._src import dtypes
|
||||
|
||||
_slice = builtins.slice
|
||||
_max = builtins.max
|
||||
|
@ -30,7 +30,7 @@ logging._warn_preinit_stderr = 0
|
||||
|
||||
from ..config import flags
|
||||
from jax._src import util, traceback_util
|
||||
from .. import dtypes
|
||||
from jax._src import dtypes
|
||||
import numpy as np
|
||||
import threading
|
||||
|
||||
|
@ -30,7 +30,7 @@ import numpy.random as npr
|
||||
|
||||
from . import api
|
||||
from . import core
|
||||
from . import dtypes as _dtypes
|
||||
from ._src import dtypes as _dtypes
|
||||
from . import lax
|
||||
from .config import flags, bool_env, config
|
||||
from ._src.util import partial, prod
|
||||
|
@ -23,7 +23,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
|
@ -40,7 +40,7 @@ from jax import lax
|
||||
from jax import linear_util
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax import tree_util
|
||||
from jax.interpreters import partial_eval, xla
|
||||
from jax.test_util import check_grads
|
||||
|
@ -27,7 +27,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax import api
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax import tree_util
|
||||
|
@ -22,7 +22,6 @@ from jax import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax._src.tree_util import _process_pytree
|
||||
from jax import flatten_util
|
||||
from jax import dtypes
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
@ -301,7 +300,7 @@ class RavelUtilTest(jtu.JaxTestCase):
|
||||
tree = [jnp.array([3], jnp.int32),
|
||||
jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
|
||||
raveled, unravel = flatten_util.ravel_pytree(tree)
|
||||
self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.float32, jnp.int32))
|
||||
self.assertEqual(raveled.dtype, jnp.promote_types(jnp.float32, jnp.int32))
|
||||
tree_ = unravel(raveled)
|
||||
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
||||
|
||||
@ -309,7 +308,7 @@ class RavelUtilTest(jtu.JaxTestCase):
|
||||
tree = [jnp.array([0], jnp.bool_),
|
||||
jnp.array([[1, 2], [3, 4]], jnp.int32)]
|
||||
raveled, unravel = flatten_util.ravel_pytree(tree)
|
||||
self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.bool_, jnp.int32))
|
||||
self.assertEqual(raveled.dtype, jnp.promote_types(jnp.bool_, jnp.int32))
|
||||
tree_ = unravel(raveled)
|
||||
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
||||
|
||||
@ -317,7 +316,7 @@ class RavelUtilTest(jtu.JaxTestCase):
|
||||
tree = [jnp.array([1.], jnp.float32),
|
||||
jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)]
|
||||
raveled, unravel = flatten_util.ravel_pytree(tree)
|
||||
self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.float32, jnp.complex64))
|
||||
self.assertEqual(raveled.dtype, jnp.promote_types(jnp.float32, jnp.complex64))
|
||||
tree_ = unravel(raveled)
|
||||
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user