mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Check for jax_enable_x64 in select_and_gather_add translation rule.
This commit is contained in:
parent
f76134e460
commit
9f84455fb2
13
jax/lax.py
13
jax/lax.py
@ -32,6 +32,7 @@ from .util import partial, prod
|
||||
from . import core
|
||||
from . import ad_util
|
||||
from . import linear_util as lu
|
||||
from .config import flags
|
||||
from .core import Primitive
|
||||
from .abstract_arrays import (UnshapedArray, ShapedArray, ConcreteArray,
|
||||
array_types, make_shaped_array)
|
||||
@ -44,6 +45,8 @@ from .util import curry, safe_zip, unzip2, prod
|
||||
from .tree_util import build_tree
|
||||
from .lib import xla_bridge
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
_max = builtins.max
|
||||
_min = builtins.max
|
||||
_reduce = six.moves.reduce
|
||||
@ -2424,10 +2427,16 @@ def _select_and_gather_add_translation(
|
||||
# XLA doesn't yet implement ReduceWindow on tuples (Google bug b/73062247), so
|
||||
# we implement a pair-wise ReduceWindow by packing two k-bit values into
|
||||
# 2k-bit unsigned integer using bit tricks. This will only work for 32-bit
|
||||
# inputs, and furthermore it also requires jax_enable_x64 to be set.
|
||||
# inputs (since we don't have 128-bit integer types).
|
||||
# TODO(phawkins): unless jax_enable_x64 is set, we won't have correct 64-bit
|
||||
# types.
|
||||
dtype = c.GetShape(operand).numpy_dtype()
|
||||
etype = xla_bridge.dtype_to_etype(dtype)
|
||||
bits = onp.finfo(dtype).bits
|
||||
if not FLAGS.jax_enable_x64 and bits >= 32:
|
||||
raise NotImplementedError(
|
||||
"Translation of select_and_gather_add requires flag --jax_enable_x64 "
|
||||
"to be set")
|
||||
etype = xla_bridge.dtype_to_etype(dtype)
|
||||
uint_etype = xla_bridge.dtype_to_etype(_UINT_DTYPES[bits])
|
||||
pair_uint_dtype = _select_and_gather_add_pair_dtype(dtype)
|
||||
pair_uint_etype = xla_bridge.dtype_to_etype(pair_uint_dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user