`getattr` turns out to be a tiny bit slower than `__get__()` on `__dict__` in the case that the attribute is absent. `getattr` appears to form an error message that is thrown away if a default is present.
Improves the device_put benchmark:
```
name old cpu/op new cpu/op delta
device_put 51.4µs ± 1% 48.9µs ± 3% -4.87% (p=0.000 n=8+9)
name old time/op new time/op delta
device_put 51.4µs ± 1% 48.9µs ± 3% -4.87% (p=0.000 n=8+9)
```
PiperOrigin-RevId: 493108288
pjit dispatch paths should check deleted array inputs when attempting to use
them. These new tests ensure that various pjit dispatch paths detect and handle
them gracefully and consistently.
Add a check to the PyArray argument handling to make the tests pass.
PiperOrigin-RevId: 492605524
testThreadsafeIndexing uses a fairly large buffer size. When overlapping many
executions under a constraint host memory for testing using an alternative
backend, this test may hit the maximum allowed memory use.
This change reduces the buffer size by half, which is likely still interesting
and runs more reliably on an alternative backend.
PiperOrigin-RevId: 492588538
Currently when JAX config values are configured via ABSL, we use the ABSL flags as a source of truth: if we read or write the JAX config option, we read or write the corresponding ABSL flag. This works but has the unfortunate downside that ABSL flags are relatively slow to read, which slows down JAX every time we read a configuration option.
However, there's fundamentally no reason we are mirroring the JAX configuration options back to ABSL in the first place. We can use ABSL flag parsing as a way only to populate the JAX configuration values. The downside is that if someone changes the ABSL flag values after parsing, that change will not be reflected in JAX's config values. JAX config changes after ABSL flags have been parsed must be made via the `jax.config.update()` API.
This gives a decent improvement on the device_put benchmark:
```
name old cpu/op new cpu/op delta
device_put 79.5µs ± 6% 69.4µs ± 7% -12.73% (p=0.000 n=10+9)
name old time/op new time/op delta
device_put 79.5µs ± 6% 69.4µs ± 7% -12.73% (p=0.000 n=10+9)
```
PiperOrigin-RevId: 492519085
This function is quite important, since it runs at every JAX primitive bind,
but it included a few redundant conditionals.
PiperOrigin-RevId: 492481837
This function is quite important, since it runs at every JAX primitive bind,
but it included a few redundant conditionals.
PiperOrigin-RevId: 492460102