[typing] clear up logic in scatter_update

Static type checkers do not parse deeply enough to know that by line 182
bucket_size cannot by None; branching on an explicit None check is easier
to follow (even for human readers)
This commit is contained in:
Jake VanderPlas 2022-09-14 15:32:52 -07:00
parent 1338864c1f
commit ed06838006

View File

@ -166,10 +166,7 @@ def _segment_update(name: str,
if num_segments is not None and num_segments < 0:
raise ValueError("num_segments must be non-negative.")
num_buckets = 1 if bucket_size is None \
else util.ceil_of_ratio(segment_ids.size, bucket_size)
if num_buckets == 1:
if bucket_size is None:
out = jnp.full((num_segments,) + data.shape[1:],
_get_identity(scatter_op, dtype), dtype=dtype)
return _scatter_update(
@ -179,6 +176,7 @@ def _segment_update(name: str,
# Bucketize indices and perform segment_update on each bucket to improve
# numerical stability for operations like product and sum.
assert reducer is not None
num_buckets = util.ceil_of_ratio(segment_ids.size, bucket_size)
out = jnp.full((num_buckets, num_segments) + data.shape[1:],
_get_identity(scatter_op, dtype), dtype=dtype)
out = _scatter_update(