Fixes: #25140
Previously, the following code:
```
def f(i, x):
return lax.switch(i, [lambda x: dict(a=x),
lambda x: dict(a=(x, x))], x)
f(0, 42)
```
resulted in the error message:
```
TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef({'a': *}) and PyTreeDef({'a': (*, *)}).
```
With this change the error message is more specific where the
difference is in the pytree structure:
```
TypeError: branch 0 output must have same type structure as branch 1 output, but there are differences:
* at output['a'], branch 0 output has pytree leaf and branch 1 output has <class 'tuple'>, so their Python types differ
```