mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
segment_max: fix identity for boolean dtype
This commit is contained in:
parent
4d14899940
commit
4f6f4e5554
@ -115,7 +115,6 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
|
||||
|
||||
|
||||
|
||||
|
||||
def _get_identity(op, dtype):
|
||||
"""Get an appropriate identity for a given operation in a given dtype."""
|
||||
if op is lax.scatter_add:
|
||||
@ -123,11 +122,15 @@ def _get_identity(op, dtype):
|
||||
elif op is lax.scatter_mul:
|
||||
return 1
|
||||
elif op is lax.scatter_min:
|
||||
if jnp.issubdtype(dtype, jnp.integer):
|
||||
if dtype == dtypes.bool_:
|
||||
return True
|
||||
elif jnp.issubdtype(dtype, jnp.integer):
|
||||
return jnp.iinfo(dtype).max
|
||||
return float('inf')
|
||||
elif op is lax.scatter_max:
|
||||
if jnp.issubdtype(dtype, jnp.integer):
|
||||
if dtype == dtypes.bool_:
|
||||
return False
|
||||
elif jnp.issubdtype(dtype, jnp.integer):
|
||||
return jnp.iinfo(dtype).min
|
||||
return -float('inf')
|
||||
else:
|
||||
|
@ -1229,6 +1229,49 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(grad, np.array([0., 0.], np.float32))
|
||||
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list({
|
||||
"testcase_name": "_{}_{}_num_segments={}_bucket_size={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype),
|
||||
reducer.__name__, num_segments, bucket_size),
|
||||
"dtype": dtype, "shape": shape,
|
||||
"reducer": reducer, "op": op, "identity": identity,
|
||||
"num_segments": num_segments, "bucket_size": bucket_size}
|
||||
for dtype in [np.bool_]
|
||||
for shape in [(8,), (7, 4), (6, 4, 2)]
|
||||
for bucket_size in [None, 2]
|
||||
for num_segments in [None, 1, 3])
|
||||
for reducer, op, identity in [
|
||||
(ops.segment_min, np.minimum, True),
|
||||
(ops.segment_max, np.maximum, False),
|
||||
]))
|
||||
def testSegmentReduceBoolean(self, shape, dtype, reducer, op, identity, num_segments, bucket_size):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
idx_rng = jtu.rand_int(self.rng(), low=-2, high=3)
|
||||
args_maker = lambda: [rng(shape, dtype), idx_rng(shape[:1], jnp.int32)]
|
||||
|
||||
if np.issubdtype(dtype, np.integer):
|
||||
if np.isposinf(identity):
|
||||
identity = np.iinfo(dtype).max
|
||||
elif np.isneginf(identity):
|
||||
identity = np.iinfo(dtype).min
|
||||
|
||||
jnp_fun = lambda data, segment_ids: reducer(
|
||||
data, segment_ids, num_segments=num_segments, bucket_size=bucket_size)
|
||||
|
||||
def np_fun(data, segment_ids):
|
||||
size = num_segments if num_segments is not None else (segment_ids.max() + 1)
|
||||
out = np.full((size,) + shape[1:], identity, dtype)
|
||||
for i, val in zip(segment_ids, data):
|
||||
if 0 <= i < size:
|
||||
out[i] = op(out[i], val).astype(dtype)
|
||||
return out
|
||||
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
if num_segments is not None:
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list({
|
||||
"testcase_name": "_{}_{}_num_segments={}_bucket_size={}".format(
|
||||
|
Loading…
x
Reference in New Issue
Block a user