mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
_scatter_jvp bug fix (#2231)
This commit is contained in:
parent
9b36238003
commit
18420936c4
@ -375,4 +375,4 @@ class MaskTrace(Trace):
|
||||
return map(partial(MaskTracer, self), out, out_shape)
|
||||
|
||||
def process_call(self, call_primitive, f, tracers, params):
|
||||
raise NotImplementedError # TODO mask-of-jit
|
||||
raise NotImplementedError # TODO mask-of-jit
|
||||
|
@ -3341,7 +3341,9 @@ def _scatter_jvp(primals, tangents, update_jaxpr, update_consts,
|
||||
new_operand = pad(new_operand, _zero(operand),
|
||||
((0, 1, 0),) + tuple((0, 0, 0) for _ in operand_shape))
|
||||
|
||||
ids_shape = onp.array(updates_shape)
|
||||
# We specify the dtype here in case `updates_shape` is an empty tuple, in
|
||||
# which case numpy defaults to float64.
|
||||
ids_shape = onp.array(updates_shape, dtype=onp.int32)
|
||||
ids_shape[dnums.update_window_dims,] = 1
|
||||
num_ids = onp.prod(ids_shape)
|
||||
update_ids = add(reshape(iota(updates_dtype, num_ids), ids_shape),
|
||||
|
Loading…
x
Reference in New Issue
Block a user