Check for jax_enable_x64 in select_and_gather_add translation rule.

This commit is contained in:
Peter Hawkins 2019-01-28 15:10:58 -05:00
parent f76134e460
commit 9f84455fb2

View File

@ -32,6 +32,7 @@ from .util import partial, prod
from . import core
from . import ad_util
from . import linear_util as lu
from .config import flags
from .core import Primitive
from .abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
array_types, make_shaped_array)
@ -44,6 +45,8 @@ from .util import curry, safe_zip, unzip2, prod
from .tree_util import build_tree
from .lib import xla_bridge
FLAGS = flags.FLAGS
_max = builtins.max
_min = builtins.max
_reduce = six.moves.reduce
@ -2424,10 +2427,16 @@ def _select_and_gather_add_translation(
# XLA doesn't yet implement ReduceWindow on tuples (Google bug b/73062247), so
# we implement a pair-wise ReduceWindow by packing two k-bit values into
# 2k-bit unsigned integer using bit tricks. This will only work for 32-bit
# inputs, and furthermore it also requires jax_enable_x64 to be set.
# inputs (since we don't have 128-bit integer types).
# TODO(phawkins): unless jax_enable_x64 is set, we won't have correct 64-bit
# types.
dtype = c.GetShape(operand).numpy_dtype()
etype = xla_bridge.dtype_to_etype(dtype)
bits = onp.finfo(dtype).bits
if not FLAGS.jax_enable_x64 and bits >= 32:
raise NotImplementedError(
"Translation of select_and_gather_add requires flag --jax_enable_x64 "
"to be set")
etype = xla_bridge.dtype_to_etype(dtype)
uint_etype = xla_bridge.dtype_to_etype(_UINT_DTYPES[bits])
pair_uint_dtype = _select_and_gather_add_pair_dtype(dtype)
pair_uint_etype = xla_bridge.dtype_to_etype(pair_uint_dtype)