mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[typing] clear up logic in scatter_update
Static type checkers do not parse deeply enough to know that by line 182 bucket_size cannot by None; branching on an explicit None check is easier to follow (even for human readers)
This commit is contained in:
parent
1338864c1f
commit
ed06838006
@ -166,10 +166,7 @@ def _segment_update(name: str,
|
||||
if num_segments is not None and num_segments < 0:
|
||||
raise ValueError("num_segments must be non-negative.")
|
||||
|
||||
|
||||
num_buckets = 1 if bucket_size is None \
|
||||
else util.ceil_of_ratio(segment_ids.size, bucket_size)
|
||||
if num_buckets == 1:
|
||||
if bucket_size is None:
|
||||
out = jnp.full((num_segments,) + data.shape[1:],
|
||||
_get_identity(scatter_op, dtype), dtype=dtype)
|
||||
return _scatter_update(
|
||||
@ -179,6 +176,7 @@ def _segment_update(name: str,
|
||||
# Bucketize indices and perform segment_update on each bucket to improve
|
||||
# numerical stability for operations like product and sum.
|
||||
assert reducer is not None
|
||||
num_buckets = util.ceil_of_ratio(segment_ids.size, bucket_size)
|
||||
out = jnp.full((num_buckets, num_segments) + data.shape[1:],
|
||||
_get_identity(scatter_op, dtype), dtype=dtype)
|
||||
out = _scatter_update(
|
||||
|
Loading…
x
Reference in New Issue
Block a user