mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Assert that reduction computations don't have constants. (#2754)
This case wouldn't work anyway, because there's no good way to pass constants to an XLA reducer.
This commit is contained in:
parent
7d716b8306
commit
9a5b8d626a
@ -3773,7 +3773,7 @@ def _reduction_computation(c, jaxpr, consts, init_value):
|
||||
shape = c.GetShape(init_value)
|
||||
axis_env = xla.AxisEnv(1) # no parallel primitives inside reductions
|
||||
subc = xla_bridge.make_computation_builder("reduction_computation")
|
||||
consts = [subc.ParameterWithShape(const) for const in consts]
|
||||
assert len(consts) == 0, "Reduction computations cannot have constants"
|
||||
args = [subc.ParameterWithShape(shape), subc.ParameterWithShape(shape)]
|
||||
out, = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, '', *args)
|
||||
return subc.Build(out)
|
||||
|
Loading…
x
Reference in New Issue
Block a user