Before:
```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```
After:
```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
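For context, a minimal sketch of the kind of call that appears to trigger this error, loosely modeled on the test named in the new message. The 2x2 mesh, the axis names `x` and `y`, the `reproduce` helper, and the use of `XLA_FLAGS` to request four virtual CPU devices are assumptions for illustration, not taken from this change. The exact wording of the message may also differ across JAX versions.

```python
import os

# Assumption: request 4 virtual CPU devices. This must be set before JAX
# initializes its backends, and is ignored if XLA_FLAGS is already set.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def reproduce():
    """Return the error message, or None if the error could not be triggered."""
    devices = jax.devices()
    if len(devices) < 4:
        # Not enough devices to build the hypothetical 2x2 mesh.
        return None

    mesh = Mesh(np.array(devices[:4]).reshape(2, 2), ("x", "y"))
    sharding = NamedSharding(mesh, P("x", "y"))

    @jax.jit
    def sharded_inp(inp):
        # Constrains the value to all four mesh devices.
        return jax.lax.with_sharding_constraint(inp, sharding)

    # device_put with an explicit device commits the array to that device,
    # which conflicts with the four-device constraint inside the jit.
    committed_inp = jax.device_put(jnp.zeros((8, 2), jnp.bfloat16), devices[0])
    try:
        sharded_inp(committed_inp)
    except ValueError as err:
        return str(err)
    return None


message = reproduce()
print(message)
```

If the error is raised, the printed message should name the offending argument (`inp`), its shape and dtype, and the source location of the conflicting constraint, as in the "After" output above.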
PiperOrigin-RevId: 508746961