Assert that reduction computations don't have constants. (#2754)

This case wouldn't work anyway, because there's no good way to pass constants to an XLA reducer.
This commit is contained in:
Peter Hawkins 2020-04-17 14:38:50 -04:00 committed by GitHub
parent 7d716b8306
commit 9a5b8d626a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3773,7 +3773,7 @@ def _reduction_computation(c, jaxpr, consts, init_value):
shape = c.GetShape(init_value)
axis_env = xla.AxisEnv(1) # no parallel primitives inside reductions
subc = xla_bridge.make_computation_builder("reduction_computation")
consts = [subc.ParameterWithShape(const) for const in consts]
assert len(consts) == 0, "Reduction computations cannot have constants"
args = [subc.ParameterWithShape(shape), subc.ParameterWithShape(shape)]
out, = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, '', *args)
return subc.Build(out)