George Necula c2adfbf1c2 [better_errors] Improve error message for lax.switch branches output structure mismatch
Fixes: #25140

Previously, the following code:
```
from jax import lax

def f(i, x):
  return lax.switch(i, [lambda x: dict(a=x),
                        lambda x: dict(a=(x, x))], x)
f(0, 42)
```

resulted in the error message:
```
TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef({'a': *}) and PyTreeDef({'a': (*, *)}).
```

With this change, the error message pinpoints where in the pytree
structure the branch outputs differ:

```
TypeError: branch 0 output must have same type structure as branch 1 output, but there are differences:
    * at output['a'], branch 0 output has pytree leaf and branch 1 output has <class 'tuple'>, so their Python types differ
```
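The message above reports each location where the two output structures diverge. As a rough illustration of that idea (not JAX's implementation), here is a pure-Python sketch. It assumes that pytrees are built only from dict, tuple and list, with everything else treated as a leaf. Real JAX uses its registry of pytree node types, and for example treats None as an empty node. `structure_diffs` and its label parameters are hypothetical names.

```python
# Hypothetical sketch, NOT JAX's implementation: only dict, tuple and list are
# treated as containers; every other value is treated as a leaf.
def _is_container(x):
    return isinstance(x, (dict, tuple, list))


def _describe(x):
    # Mimic the wording in the error message: leaves vs. container classes.
    return str(type(x)) if _is_container(x) else "pytree leaf"


def structure_diffs(a, b, path="output",
                    name_a="branch 0 output", name_b="branch 1 output"):
    """Return a list of messages, one per location where a and b differ."""
    if not _is_container(a) and not _is_container(b):
        return []  # two leaves: the structures agree at this path
    if type(a) is not type(b):
        return [f"at {path}, {name_a} has {_describe(a)} and {name_b} has "
                f"{_describe(b)}, so their Python types differ"]
    if isinstance(a, dict):
        if set(a) != set(b):
            return [f"at {path}, the dict keys differ: "
                    f"{sorted(a, key=repr)} vs {sorted(b, key=repr)}"]
        diffs = []
        for k in sorted(a, key=repr):
            diffs += structure_diffs(a[k], b[k], f"{path}[{k!r}]", name_a, name_b)
        return diffs
    # tuple or list of the same type: compare lengths, then recurse element-wise
    if len(a) != len(b):
        return [f"at {path}, the lengths differ: {len(a)} vs {len(b)}"]
    diffs = []
    for i, (x, y) in enumerate(zip(a, b)):
        diffs += structure_diffs(x, y, f"{path}[{i}]", name_a, name_b)
    return diffs


# The two branch outputs from the example above:
for msg in structure_diffs({"a": 42}, {"a": (42, 42)}):
    print("    *", msg)
```

Collecting every difference, rather than stopping at the first, lets the error list each mismatch separately. This is the same one-bullet-per-location shape as the new message.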
2025-01-10 08:03:33 +02:00