segment_max: fix identity for boolean dtype

Jake VanderPlas 2022-03-15 09:20:20 -07:00
parent 4d14899940
commit 4f6f4e5554
2 changed files with 49 additions and 3 deletions
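
For context, jax.ops.segment_min and jax.ops.segment_max fill segments that receive no entries with the identity chosen by _get_identity; before this change, boolean inputs fell through to the floating-point branch (+/-inf), which is not representable as bool. Below is a minimal sketch of the intended post-fix behavior; the expected outputs in the comments are an inference from the new identities, not taken from the commit itself.

    import jax.numpy as jnp
    from jax import ops

    # Boolean data reduced per segment; segment 2 receives no entries,
    # so it is filled with the reduction's identity.
    data = jnp.array([True, False, True, True])
    segment_ids = jnp.array([0, 0, 1, 1])

    # With boolean identities of False (max) and True (min), the expected results are:
    #   segment_max -> [ True  True False]
    #   segment_min -> [False  True  True]
    print(ops.segment_max(data, segment_ids, num_segments=3))
    print(ops.segment_min(data, segment_ids, num_segments=3))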


@@ -115,7 +115,6 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
 def _get_identity(op, dtype):
   """Get an appropriate identity for a given operation in a given dtype."""
   if op is lax.scatter_add:
@@ -123,11 +122,15 @@ def _get_identity(op, dtype):
   elif op is lax.scatter_mul:
     return 1
   elif op is lax.scatter_min:
-    if jnp.issubdtype(dtype, jnp.integer):
+    if dtype == dtypes.bool_:
+      return True
+    elif jnp.issubdtype(dtype, jnp.integer):
       return jnp.iinfo(dtype).max
     return float('inf')
   elif op is lax.scatter_max:
-    if jnp.issubdtype(dtype, jnp.integer):
+    if dtype == dtypes.bool_:
+      return False
+    elif jnp.issubdtype(dtype, jnp.integer):
       return jnp.iinfo(dtype).min
     return -float('inf')
   else:

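The choice of identities follows from the boolean ordering False < True: max(x, False) == x and min(x, True) == x for every bool x, so False and True are the identities of boolean max and min respectively, whereas iinfo bounds and +/-inf are undefined for bool. A throwaway check of that property (illustrative only, not part of the commit):

    for x in (False, True):
      assert max(x, False) == x  # False is the identity of boolean max
      assert min(x, True) == x   # True is the identity of boolean min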

@@ -1229,6 +1229,49 @@ class IndexedUpdateTest(jtu.JaxTestCase):
     self.assertAllClose(grad, np.array([0., 0.], np.float32))
 
+  @parameterized.named_parameters(itertools.chain.from_iterable(
+      jtu.cases_from_list({
+        "testcase_name": "_{}_{}_num_segments={}_bucket_size={}".format(
+          jtu.format_shape_dtype_string(shape, dtype),
+          reducer.__name__, num_segments, bucket_size),
+        "dtype": dtype, "shape": shape,
+        "reducer": reducer, "op": op, "identity": identity,
+        "num_segments": num_segments, "bucket_size": bucket_size}
+       for dtype in [np.bool_]
+       for shape in [(8,), (7, 4), (6, 4, 2)]
+       for bucket_size in [None, 2]
+       for num_segments in [None, 1, 3])
+      for reducer, op, identity in [
+        (ops.segment_min, np.minimum, True),
+        (ops.segment_max, np.maximum, False),
+      ]))
+  def testSegmentReduceBoolean(self, shape, dtype, reducer, op, identity, num_segments, bucket_size):
+    rng = jtu.rand_default(self.rng())
+    idx_rng = jtu.rand_int(self.rng(), low=-2, high=3)
+    args_maker = lambda: [rng(shape, dtype), idx_rng(shape[:1], jnp.int32)]
+
+    if np.issubdtype(dtype, np.integer):
+      if np.isposinf(identity):
+        identity = np.iinfo(dtype).max
+      elif np.isneginf(identity):
+        identity = np.iinfo(dtype).min
+
+    jnp_fun = lambda data, segment_ids: reducer(
+        data, segment_ids, num_segments=num_segments, bucket_size=bucket_size)
+
+    def np_fun(data, segment_ids):
+      size = num_segments if num_segments is not None else (segment_ids.max() + 1)
+      out = np.full((size,) + shape[1:], identity, dtype)
+      for i, val in zip(segment_ids, data):
+        if 0 <= i < size:
+          out[i] = op(out[i], val).astype(dtype)
+      return out
+
+    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
+    if num_segments is not None:
+      self._CompileAndCheck(jnp_fun, args_maker)
+
   @parameterized.named_parameters(itertools.chain.from_iterable(
       jtu.cases_from_list({
         "testcase_name": "_{}_{}_num_segments={}_bucket_size={}".format(