We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do.
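The keep-alive contract described here mirrors how Python's buffer protocol already works for ordinary objects: a `memoryview` holds a reference that keeps the underlying buffer alive even after the owner is dropped. A minimal stdlib sketch (using `array.array` as a stand-in for `jax.Array`, since the change itself is in the C++ CPU PJRT client):

```python
import array

# Create a buffer and take an external reference to it via the
# buffer protocol.
buf = array.array('f', [1.0, 2.0, 3.0])
view = memoryview(buf)

# Dropping the original owner must not invalidate the view: the
# memoryview keeps the underlying storage alive.
del buf
assert view[1] == 2.0
```

The C++ CHECK failure arose because the CPU client previously freed the buffer on Delete() even while such an external reference existed; the fix makes Delete() honor this contract.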
Fixes https://github.com/google/jax/issues/14713
PiperOrigin-RevId: 535248553
In that case, reshard the array and then create a host local array from that.
Also improve the shard mismatch error that jax.Array raises.
PiperOrigin-RevId: 531397741
Why? np.prod is generally used for static operations on shapes, but it
has an unfortunate corner case: np.prod([]) returns a float (1.0).
math.prod, available as of Python 3.8, returns the int 1 and is a better solution here.
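The difference matters for shape arithmetic, where a float product silently poisons downstream integer shapes. A quick illustration:

```python
import math

# math.prod of an empty sequence is the int 1, so shape products
# stay integers even for rank-0 shapes.
assert math.prod([]) == 1
assert isinstance(math.prod([]), int)
# (np.prod([]) would return 1.0, a float.)

# Typical use: total element count of a static shape.
assert math.prod((2, 3, 4)) == 24
```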
This has the downside of keeping around the UnloadedMeshComputation,
but it makes the serialize() API easier to understand.
PiperOrigin-RevId: 518715469
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
Fixes https://github.com/google/jax/issues/14713
PiperOrigin-RevId: 513086379
We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array.
Fixes https://github.com/google/jax/issues/14713
PiperOrigin-RevId: 513047925
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PartitionSpecs to in_shardings and out_shardings.
PiperOrigin-RevId: 510671300
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).
By running single-device tests on a single TPU chip, running the test
suite goes from 1hr 45m to 35m (both timings are running slow tests).
I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still took 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, whereas pytest reuses the number of processes
initially specified; since starting and stopping the TPU runtime takes
a few seconds each time, that may be adding up. Single-process bazel
also appears to be slower than single-process pytest, which I haven't
looked into yet.
This adds experimental APIs to `serialize_executable.py`:
`compile_and_serialize(lowered)`
and
`load_compiled(serialized, in_tree, out_tree)`
for serializing and deserializing executables.
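The round-trip pattern these APIs enable can be sketched with a toy stand-in. The sketch below uses pickle in place of the real executable serialization, and a plain dict in place of a compiled executable; only the names `compile_and_serialize` and `load_compiled` and the `(serialized, in_tree, out_tree)` shape come from the text above:

```python
import pickle

def compile_and_serialize(lowered):
    # Toy stand-in: the real API serializes a compiled XLA executable;
    # here we just pickle the "lowered" object and fabricate trees.
    in_tree, out_tree = "(*,)", "*"
    return pickle.dumps(lowered), in_tree, out_tree

def load_compiled(serialized, in_tree, out_tree):
    # Toy stand-in: deserialize and pair the result with its pytrees.
    return pickle.loads(serialized), in_tree, out_tree

lowered = {"hlo": [1, 2, 3]}
serialized, in_tree, out_tree = compile_and_serialize(lowered)
restored, in_tree2, out_tree2 = load_compiled(serialized, in_tree, out_tree)
assert restored == lowered and (in_tree2, out_tree2) == (in_tree, out_tree)
```

The key design point is that the in/out pytree structures are not recoverable from the serialized executable bytes alone, which is why load_compiled takes them as explicit arguments.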
PiperOrigin-RevId: 489014705