Note that one key difference between `lax.select_p` and `lax.select_n_p` is that the order of the cases is reversed for boolean predicates. This merited a new name to minimize confusion.
Use lax.select_n() in conditional batching. This means that we only produce one `select_n()` primitive for each conditional output, rather than a tree. While this has no effect on the number of HLO operators we generate, it can reduce the number of jaxpr equations significantly.
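A small illustration of both points (operand values here are arbitrary; this only sketches the documented call signatures):

```python
import jax.numpy as jnp
from jax import lax

pred = jnp.array([True, False, True])
on_true = jnp.array([1., 2., 3.])
on_false = jnp.array([10., 20., 30.])

# Reversed case order for boolean predicates: select takes (pred, on_true,
# on_false), while select_n takes (which, *cases) and picks cases[which],
# so the False case comes first.
out_a = lax.select(pred, on_true, on_false)    # [ 1., 20.,  3.]
out_b = lax.select_n(pred, on_false, on_true)  # [ 1., 20.,  3.]

# An n-way choice is a single select_n rather than a tree of two-way selects.
idx = jnp.array([0, 2, 1])
cases = [jnp.zeros(3), jnp.ones(3), jnp.full(3, 2.)]
nested = lax.select(idx == 2, cases[2],
                    lax.select(idx == 1, cases[1], cases[0]))  # [0., 2., 1.]
flat = lax.select_n(idx, *cases)                               # [0., 2., 1.]
```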
PiperOrigin-RevId: 427517899
Speed up source_info_util.user_frames by using a newly refactored Traceback.raw_frames() method. Since we are interested in only one frame, it's best to avoid wasted work on all the frames we are going to ignore.
Change traceback.raw_frames() to return the transpose of what it previously returned: this way we only need to build 3 Python objects, rather than n + 1 Python objects for n frames.
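A hedged sketch of the "transpose" (the frame fields and helper names here are illustrative, not the actual Traceback implementation):

```python
# Before: one (code, lasti) tuple per frame, i.e. n + 1 Python objects
# (n pair tuples plus the outer list) for n frames.
def raw_frames_rows(frames):
  return [(f.code, f.lasti) for f in frames]

# After: the transpose -- two parallel lists plus the enclosing tuple,
# i.e. 3 Python objects regardless of n.
def raw_frames_columns(frames):
  return [f.code for f in frames], [f.lasti for f in frames]
```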
PiperOrigin-RevId: 427320674
This change makes ndarray a bit easier for tooling to handle, since de facto
all these methods are supposed to return *something*, but the type inferrable
from their default implementations is None.
As a hand-wavy aside, in a type stub
def f(): ...
could be treated equivalently to
def f() -> Any: ...
because there is no body to infer return type from, and Any is a reasonable
fallback type. In a .py file, however, f is no longer just a function *type*
but also a function *implementation*, and thus it has an inferrable return
type (which is None here).
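A hedged illustration of the inference issue (the class and method names are generic placeholders, not the actual jax source):

```python
from typing import Any

class NDArrayLike:
  # In a .py file this stub-style body is a real implementation, so a
  # checker infers its return type as None.
  def transpose(self):
    ...

  # An explicit annotation (here Any as a fallback) gives tooling something
  # to work with, mirroring what a .pyi stub would imply.
  def sum(self) -> Any:
    ...
```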
* support and test the edge case where the axis argument is the empty tuple ()
* replace the swapaxes + reshape approach with a single call to lax.reshape for computational efficiency
* add a check for repeated axes and raise a ValueError
* introduce and adapt the corresponding numpy code to swap and reshape the axes to be quantiled
* introduce code to accommodate reintroducing those axes when keepdims=True (see the usage sketch after this list)
* add test cases
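A hedged usage sketch of what the new behavior enables (the exact output shapes below assume a (2, 3, 4) input):

```python
import jax.numpy as jnp

x = jnp.arange(24.).reshape(2, 3, 4)

# Quantile over a tuple of axes.
q = jnp.quantile(x, 0.5, axis=(0, 2))                      # shape (3,)

# keepdims=True reintroduces the reduced axes with size 1.
q_keep = jnp.quantile(x, 0.5, axis=(0, 2), keepdims=True)  # shape (1, 3, 1)

# Edge case from this change: axis=() is accepted as well.
q_empty = jnp.quantile(x, 0.5, axis=())
```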
Finding the user frame in a traceback is something we do for every jaxpr equation, and it shows up in profiles. We expect a reasonable amount of locality, e.g., many lines of code with similar provenance appearing together, so this seems like a place for a small LRU cache.
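A minimal sketch of the caching idea, assuming the lookup can be keyed on something hashable such as a tuple of raw frames (the helper names are illustrative, not the actual source_info_util code):

```python
import functools

def _is_user_frame(frame) -> bool:
  # Hypothetical predicate: treat anything outside the jax package as user code.
  return "/jax/" not in frame.file_name

@functools.lru_cache(maxsize=64)
def _user_frame_cached(raw_frames: tuple):
  # Consecutive jaxpr equations tend to share provenance, so this hits often.
  return next((f for f in raw_frames if _is_user_frame(f)), None)
```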
PiperOrigin-RevId: 427020947
I noticed this in passing while working on https://github.com/google/jax/pull/9468. It seems strange to me that we would change the dtype when raising a ShapedArray to a ShapedArray, and indeed it seems not to be necessary.
PiperOrigin-RevId: 427011028
The shard's dimensions might be too small and trigger asserts, even though
the shape has no influence on sharding specs.
PiperOrigin-RevId: 426955706
Avoid forming a new ShapedArray if we already have a ShapedArray.
Don't use the slower safe map() when canonicalizing shapes. We're going
to form a tuple anyway.
Before:
```
In [1]: import numpy as np ; from jax import core, numpy as jnp
In [2]: x = core.ShapedArray((100,100), jnp.float32)
In [3]: %timeit core.raise_to_shaped(x)
4.11 µs ± 30.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
After:
```
In [1]: import numpy as np ; from jax import core, numpy as jnp
In [2]: x = core.ShapedArray((100,100), jnp.float32)
In [3]: %timeit core.raise_to_shaped(x)
207 ns ± 0.131 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```
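A minimal sketch of the fast-path idea (illustrative, not the actual jax.core source): return the existing ShapedArray untouched instead of constructing a new one.

```python
from jax.core import ShapedArray

def raise_to_shaped_sketch(aval, weak_type=None):
  if weak_type is None:
    weak_type = aval.weak_type
  if type(aval) is ShapedArray and weak_type == aval.weak_type:
    return aval                      # fast path: reuse the existing object
  # Fallback: build a fresh ShapedArray (shape and dtype are preserved).
  return ShapedArray(aval.shape, aval.dtype, weak_type=weak_type)
```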
Deletes the documentation that explains the algorithm.
I don't think this level of detail is necessary for users.
We'll write a paper to explain it in detail very soon.
PiperOrigin-RevId: 426546480