_scatter_jvp bug fix (#2231)

This commit is contained in:
Skye Wanderman-Milne 2020-02-14 18:09:52 -08:00 committed by GitHub
parent 9b36238003
commit 18420936c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 2 deletions

View File

@ -375,4 +375,4 @@ class MaskTrace(Trace):
return map(partial(MaskTracer, self), out, out_shape)
def process_call(self, call_primitive, f, tracers, params):
raise NotImplementedError # TODO mask-of-jit
raise NotImplementedError # TODO mask-of-jit

View File

@ -3341,7 +3341,9 @@ def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
new_operand = pad(new_operand, _zero(operand),
((0, 1, 0),) + tuple((0, 0, 0) for _ in operand_shape))
ids_shape = onp.array(updates_shape)
# We specify the dtype here in case `updates_shape` is an empty tuple, in
# which case numpy defaults to float64.
ids_shape = onp.array(updates_shape, dtype=onp.int32)
ids_shape[dnums.update_window_dims,] = 1
num_ids = onp.prod(ids_shape)
update_ids = add(reshape(iota(updates_dtype, num_ids), ids_shape),