--
371c5a45ea08c8e92136761149d0016077a58652 by Jake VanderPlas <jakevdp@google.com>:
pytree doc: add discussion of children vs aux_data
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/15007 from jakevdp:pytree-doc 371c5a45ea08c8e92136761149d0016077a58652
PiperOrigin-RevId: 517149897
Fixesiree-org/iree-jax#57
An alternative fix would've been just to add the dtype attribute to IreeBuffer.
But it seems better not to make demands on the underlying runtime objects when
we don't need to.
I had to run the test with:
`JAX_PLATFORM_NAME=iree JAX_ARRAY=0 JAX_JIT_PJIT_API_MERGE=0 python tests/dynamic_api_test.py DynamicShapeTest.test_iree_buffer_doesnt_need_dtype_attribute`
We refer to the feature as serialization rather than just lowering,
because the former is both more widely understood and is actually
more accurate because jax2tf will both lower to StableHLO and then
serialize to StableHLO with compatibility guarantees.
This is part of launching the new version of jax2tf with native
serialization.
For now we keep also the parameter `experimental_native_lowering` and
the flag `jax2tf_default_experimental_native_lowering`, until we transition
projects using these flags to the new ones (separate change).
PiperOrigin-RevId: 516864636
I was looking at some profiles and noticed canonicalize_shape showing up as a noticeable
overhead in certain cases. Which makes sense, given that we carefully check all possible
cases before trying to consider integers as plausible elements (which are the most popular
_by far_). And this function is pretty hot, because it gets called any time we create a new
`ShapedArray`.
I wrote a small benchmark that repeatedly calls canonicalize_shape on a 4-sized tuple of
integers.
Before:
7.62µs ± 8%
After:
1.42µs ± 2%
So a pretty easy 5x improvement overall. And in more real cases, when resharding an array
onto 8 TPUs, 50% of the time was spent on creating shapes for avals of device buffers.
PiperOrigin-RevId: 516795311
* Define use_cpp_class and use_cpp_method decorators as no-ops for type checking.
* Remove the use of abc.ABC when defining the Sharding type. This triggers a pytype bug: the easiest fix seems to be to skip the use of the ABC.
* Write use_cpp_class decorator differently on ArrayImpl to work around pytype bug.
* Fix a few new type errors.
PiperOrigin-RevId: 516631428
We had avoiding this previously because dividing by zero is
a densifying operation, but we already support mul which has
similar issues if the operand contains infinities.