* give an error for NumPy indexing with slices when the elements
of the slices are not constant. This check existed, but was
throwing an error when the elements are dimension polynomials.
* give an error for NumPy indexing with slices when the dimension
size is not constant.
* Improvements in the handling of enable_xla=False for shape
polymorphism.
* Added test cases for the above.
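The constancy check described above can be sketched in plain Python. This is an illustrative stand-in, not the actual jax2tf implementation: `DimPoly` is a placeholder for a real dimension polynomial, and `check_constant_slice` is a hypothetical helper showing the rule that slice elements and the dimension size must be concrete ints.

```python
class DimPoly:
    """Stand-in for a symbolic dimension polynomial such as 'b' or '2*b + 1'."""
    def __init__(self, name):
        self.name = name

def check_constant_slice(s: slice, dim_size):
    """Raise if any slice element, or the dimension size, is not a constant."""
    for part in (s.start, s.stop, s.step):
        if part is not None and not isinstance(part, int):
            raise ValueError(f"slice elements must be constant, got {part!r}")
    if not isinstance(dim_size, int):
        raise ValueError(f"dimension size must be constant, got {dim_size!r}")

check_constant_slice(slice(0, 2), 10)  # fine: everything is a concrete int
```

Passing `DimPoly("b")` as a slice element or as the dimension size raises the error instead of failing deeper in tracing, which is the behavior the two bullets above describe.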
It was confusing to overload len() on avals, since we sometimes think of
avals as shapes paired with dtypes, and in that case len(aval) should
perhaps mean len(aval.shape). The only place that relied on this
behavior was sparse/ops.py.
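The ambiguity can be shown with a minimal stand-in class (the real aval is `jax.core.ShapedArray`; this toy `Aval` only illustrates the two readings of `len()`):

```python
class Aval:
    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype

    # The overload in question: leading axis size, like len() on an ndarray.
    def __len__(self):
        return self.shape[0]

a = Aval((4, 5), "float32")
assert len(a) == 4        # leading dimension size (ndarray-like reading)
assert len(a.shape) == 2  # number of dimensions (shape-paired-with-dtype reading)
```

Writing `a.shape[0]` or `len(a.shape)` explicitly is unambiguous, which is why removing the overload is the cleaner choice.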
1. factor out rbg_prng_impl and unsafe_rbg_prng_impl. the former uses
threefry2x32 for split and fold_in, while the latter uses untested
heuristics based on calling rng_bit_generator itself as a kind of
hash function
2. for unsafe_rbg_prng_impl's split and fold_in, generate longer
sequences from rng_bit_generator (10x iterations), which may be useful
on some backends
3. for unsafe_rbg_prng_impl, actually apply rng_bit_generator as our
'hash function' in fold_in
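The structure of the two impls can be sketched as follows. Everything here is a toy stand-in: threefry2x32 is replaced by a keyed blake2b mix, rng_bit_generator by a counter-based hash, and the 10x-longer sequence is modeled by drawing ten blocks and keeping the last. Only the shape of the split (one impl hashes with a dedicated mix function, the other reuses the bit generator itself) reflects the commit.

```python
import hashlib
from typing import Callable, NamedTuple

def _mix(key: bytes, data: bytes) -> bytes:
    # Stand-in for threefry2x32: a keyed hash producing a new 8-byte key.
    return hashlib.blake2b(data, key=key, digest_size=8).digest()

def _rbg(key: bytes, n: int) -> list:
    # Stand-in for rng_bit_generator: derive n blocks from the key.
    return [hashlib.blake2b(key + i.to_bytes(4, "big"), digest_size=8).digest()
            for i in range(n)]

class PRNGImpl(NamedTuple):
    split: Callable
    fold_in: Callable

# "rbg": split and fold_in go through the dedicated (threefry-like) mix.
rbg_prng_impl = PRNGImpl(
    split=lambda key, n: [_mix(key, i.to_bytes(4, "big")) for i in range(n)],
    fold_in=lambda key, data: _mix(key, data.to_bytes(4, "big")),
)

# "unsafe_rbg": reuse the bit generator itself as the 'hash function',
# drawing a longer sequence (10 blocks here) and keeping the last block.
unsafe_rbg_prng_impl = PRNGImpl(
    split=lambda key, n: [_rbg(key, 10 * (i + 1))[-1] for i in range(n)],
    fold_in=lambda key, data: _rbg(key + data.to_bytes(4, "big"), 10)[-1],
)
```

Both impls expose the same `split`/`fold_in` interface, so callers can select one without changing downstream code.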
Co-authored-by: Roy Frostig <frostig@google.com>
Co-authored-by: Anselm Levskaya <levskaya@google.com>
This is to prevent false cache hits when the compiler behavior is
changed via flags. Flags known to not affect the compiled executable
(e.g. dumping HLO) are excluded from the key.
Note that any XLA flags with arguments should use = and not a space,
e.g. `--xla_flag=value`, not `--xla_flag value`. I believe this is
already a requirement of ABSL flags in general, but I'm not 100% sure.
Also note that this doesn't currently support XLA flags specified via
--flagfile. Please file a feature request if this is needed.
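The key construction can be sketched roughly like this. The flag names and the exclusion set are illustrative only, not the real list; the point is that flags known to affect only debug output are dropped before hashing, and that the `--name=value` form makes the flag name trivially separable.

```python
import hashlib

# Illustrative: flags that change debug output only, not the executable.
_EXCLUDED_FLAGS = {"--xla_dump_to", "--xla_dump_hlo_as_text"}

def cache_key(hlo_text: str, xla_flags) -> str:
    """Hash the HLO together with the compiler-relevant flags."""
    relevant = sorted(
        f for f in xla_flags
        # Flags use '--name=value' form, so the name is everything before '='.
        if f.split("=", 1)[0] not in _EXCLUDED_FLAGS)
    h = hashlib.sha256()
    h.update(hlo_text.encode())
    for f in relevant:
        h.update(f.encode())
    return h.hexdigest()
```

With this shape, adding a dump flag leaves the key unchanged, while changing a flag that alters compilation produces a different key and so a cache miss.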
The all-gather and reduce-scatter HLOs were wired through for GPU but not TPU, but they should also work there (and be more performant than the all-reduce based fallback).
One interesting angle of pjit is that it is a boundary between a multi-controller
world in which `.shape` attributes of all arrays (and avals) correspond to
slices of data that are internal to a given process, and a single-controller
world where `.shape` refers to the global array constructed by concatenating
per-device chunks. I hadn't fully appreciated this previously, which caused
pjit nests (and xmaps in pjits) to incorrectly increase shapes at every
level of nesting, when only the outermost call should make the change.
We now keep track of a flag that determines whether the positional shape of
avals we see is global or local in any given context. Note that sizes of named
axes have been and still are global only.
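The local-to-global conversion at that boundary amounts to scaling each sharded positional axis by the size of the mesh axis it is sharded over, exactly once. The helper below is an illustrative sketch with made-up names, not the pjit internals:

```python
def local_to_global(local_shape, mesh_axis_sizes, sharded_axes):
    """Scale each sharded positional axis by its mesh axis size.

    local_shape: per-process shape of the array
    mesh_axis_sizes: mesh axis name -> number of devices along that axis
    sharded_axes: positional axis index -> mesh axis name it is sharded over
    """
    shape = list(local_shape)
    for dim, mesh_axis in sharded_axes.items():
        shape[dim] *= mesh_axis_sizes[mesh_axis]
    return tuple(shape)

# A (4, 8) per-process slice sharded over a 2-device mesh axis on dim 0
# corresponds to a (8, 8) global array.
assert local_to_global((4, 8), {"x": 2}, {0: "x"}) == (8, 8)
```

The flag described above determines whether this scaling has already happened: nested pjit calls must see shapes as already global and apply no further scaling.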
PiperOrigin-RevId: 400949756
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...
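A few of the fixes above in minimal stand-alone form (the actual diffs touch many files; these examples are illustrative):

```python
import logging
import os

# dangerous-default-value: use a None sentinel instead of a mutable default,
# so calls don't share (and mutate) one list object.
def append_item(item, items=None):
    if items is None:
        items = []
    items.append(item)
    return items

# invalid-envvar-default: getenv() defaults should be strings.
cache_dir = os.getenv("JAX_CACHE_DIR", "")   # not os.getenv(..., 0)

# logging-not-lazy: pass format args so formatting only happens when the
# record is actually emitted.
logging.debug("compiled %s in %.2fs", "fn", 0.5)   # not "... %s" % (...)
```

Each is behavior-preserving in the common case; the mutable-default fix also removes the latent shared-state bug the bullet above describes.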
PiperOrigin-RevId: 400858477